ericmagalhaes committed

Commit ed72c0e
Parent: 49a817b

initial commit

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitignore +174 -0
  2. LICENSE +209 -0
  3. __init__.py +1 -0
  4. app.py +287 -0
  5. app_sadtalker.py +111 -0
  6. cog.yaml +35 -0
  7. inference.py +145 -0
  8. launcher.py +204 -0
  9. predict.py +192 -0
  10. quick_demo.ipynb +221 -0
  11. req.txt +22 -0
  12. requirements.txt +30 -0
  13. requirements3d.txt +21 -0
  14. scripts/download_models.sh +15 -0
  15. src/audio2exp_models/audio2exp.py +41 -0
  16. src/audio2exp_models/networks.py +74 -0
  17. src/audio2pose_models/audio2pose.py +94 -0
  18. src/audio2pose_models/audio_encoder.py +64 -0
  19. src/audio2pose_models/cvae.py +149 -0
  20. src/audio2pose_models/discriminator.py +76 -0
  21. src/audio2pose_models/networks.py +140 -0
  22. src/audio2pose_models/res_unet.py +65 -0
  23. src/config/auido2exp.yaml +58 -0
  24. src/config/auido2pose.yaml +49 -0
  25. src/config/facerender.yaml +45 -0
  26. src/config/facerender_still.yaml +45 -0
  27. src/config/similarity_Lm3D_all.mat +0 -0
  28. src/face3d/data/__init__.py +116 -0
  29. src/face3d/data/base_dataset.py +125 -0
  30. src/face3d/data/flist_dataset.py +125 -0
  31. src/face3d/data/image_folder.py +66 -0
  32. src/face3d/data/template_dataset.py +75 -0
  33. src/face3d/extract_kp_videos.py +108 -0
  34. src/face3d/extract_kp_videos_safe.py +151 -0
  35. src/face3d/models/__init__.py +67 -0
  36. src/face3d/models/arcface_torch/README.md +164 -0
  37. src/face3d/models/arcface_torch/backbones/__init__.py +25 -0
  38. src/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
  39. src/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
  40. src/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
  41. src/face3d/models/arcface_torch/configs/3millions.py +23 -0
  42. src/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
  43. src/face3d/models/arcface_torch/configs/__init__.py +0 -0
  44. src/face3d/models/arcface_torch/configs/base.py +56 -0
  45. src/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
  46. src/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
  47. src/face3d/models/arcface_torch/configs/glint360k_r18.py +26 -0
  48. src/face3d/models/arcface_torch/configs/glint360k_r34.py +26 -0
  49. src/face3d/models/arcface_torch/configs/glint360k_r50.py +26 -0
  50. src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +26 -0
.gitignore ADDED
@@ -0,0 +1,174 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+
+ examples/results/*
+ gfpgan/*
+ checkpoints/*
+ assets/*
+ results/*
+ Dockerfile
+ start_docker.sh
+ start.sh
+
+ checkpoints
+
+ # Mac
+ .DS_Store
LICENSE ADDED
@@ -0,0 +1,209 @@
+ Tencent is pleased to support the open source community by making SadTalker available.
+
+ Copyright (C), a Tencent company. All rights reserved.
+
+ SadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below.
+
+ Terms of the Apache License Version 2.0:
+ ---------------------------------------------
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
__init__.py ADDED
@@ -0,0 +1 @@
+
app.py ADDED
@@ -0,0 +1,287 @@
+ from flask import Flask, request, jsonify
+ import torch
+ import shutil
+ import os
+ import sys
+ from argparse import ArgumentParser
+ from time import strftime
+ from argparse import Namespace
+ from src.utils.preprocess import CropAndExtract
+ from src.test_audio2coeff import Audio2Coeff
+ from src.facerender.animate import AnimateFromCoeff
+ from src.generate_batch import get_data
+ from src.generate_facerender_batch import get_facerender_data
+ from src.utils.init_path import init_path
+ import tempfile
+ from openai import OpenAI
+ import threading
+ import elevenlabs
+ from elevenlabs import set_api_key, generate, play, clone
+ from flask_cors import CORS, cross_origin
+ from flask_swagger_ui import get_swaggerui_blueprint
+ import uuid
+
+ class AnimationConfig:
+     def __init__(self, driven_audio_path, source_image_path, result_folder):
+         self.driven_audio = driven_audio_path
+         self.source_image = source_image_path
+         self.ref_eyeblink = None
+         self.ref_pose = None
+         self.checkpoint_dir = './checkpoints'
+         self.result_dir = result_folder
+         self.pose_style = 1
+         self.batch_size = 1
+         self.size = 256
+         self.expression_scale = 1
+         self.input_yaw = None
+         self.input_pitch = None
+         self.input_roll = None
+         self.enhancer = 'gfpgan'
+         self.background_enhancer = None
+         self.cpu = False
+         self.face3dvis = False
+         self.still = False
+         self.preprocess = 'crop'
+         self.verbose = False
+         self.old_version = False
+         self.net_recon = 'resnet50'
+         self.init_path = None
+         self.use_last_fc = False
+         self.bfm_folder = './checkpoints/BFM_Fitting/'
+         self.bfm_model = 'BFM_model_front.mat'
+         self.focal = 1015.
+         self.center = 112.
+         self.camera_d = 10.
+         self.z_near = 5.
+         self.z_far = 15.
+         self.device = 'cpu'
+
+ # Define the Swagger UI blueprint
+ SWAGGER_URL = "/swagger"
+ API_URL = "/static/swagger.json"
+
+ swagger_ui_blueprint = get_swaggerui_blueprint(
+     SWAGGER_URL,
+     API_URL,
+     config={
+         'app_name': 'Access API'
+     }
+ )
+
+ app = Flask(__name__)
+ CORS(app)
+ app.register_blueprint(swagger_ui_blueprint, url_prefix=SWAGGER_URL)
+
+ app.config['temp_response'] = None
+ app.config['generation_thread'] = None
+
+ TEMP_DIR = tempfile.TemporaryDirectory()
+
+
+ def main(args):
+     pic_path = args.source_image
+     audio_path = args.driven_audio
+     save_dir = args.result_dir
+     # save_dir = os.path.join(args.result_folder, strftime("%Y_%m_%d_%H.%M.%S"))
+     # os.makedirs(save_dir, exist_ok=True)
+     print('save_dir', save_dir)
+     pose_style = args.pose_style
+     device = args.device
+     batch_size = args.batch_size
+     input_yaw_list = args.input_yaw
+     input_pitch_list = args.input_pitch
+     input_roll_list = args.input_roll
+     ref_eyeblink = args.ref_eyeblink
+     ref_pose = args.ref_pose
+
+     dir_path = os.path.dirname(os.path.realpath(__file__))
+     current_root_path = dir_path
+     print('current_root_path ', current_root_path)
+
+     sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
+     print('sadtalker_paths ', sadtalker_paths)
+
+     # init models
+     preprocess_model = CropAndExtract(sadtalker_paths, device)
+     audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
+     animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+
+     # crop image and extract 3DMM coefficients from the source image
+     first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+     os.makedirs(first_frame_dir, exist_ok=True)
+     first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,
+                                                                            source_image_flag=True, pic_size=args.size)
+
+     print('first_coeff_path ', first_coeff_path)
+     print('crop_pic_path ', crop_pic_path)
+
+     if first_coeff_path is None:
+         print("Can't get the coeffs of the input")
+         return
+
+     if ref_eyeblink is not None:
+         ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
+         ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
+         os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
+         ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
+     else:
+         ref_eyeblink_coeff_path = None
+     print('ref_eyeblink_coeff_path', ref_eyeblink_coeff_path)
+
+     if ref_pose is not None:
+         if ref_pose == ref_eyeblink:
+             ref_pose_coeff_path = ref_eyeblink_coeff_path
+         else:
+             ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
+             ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
+             os.makedirs(ref_pose_frame_dir, exist_ok=True)
+             ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
+     else:
+         ref_pose_coeff_path = None
+     print('ref_pose_coeff_path', ref_pose_coeff_path)
+
+     # audio2coeff
+     batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
+     print('batch', batch)
+     coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+     if args.face3dvis:
+         from src.face3d.visualize import gen_composed_video
+         gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
+
+     # coeff2video
+     data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
+                                batch_size, input_yaw_list, input_pitch_list, input_roll_list,
+                                expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
+
+     print('data ', data)
+     print('save_dir ', save_dir)
+     print('pic_path ', pic_path)
+     print('crop ', crop_info)
+
+     result, base64_video = animate_from_coeff.generate(data, save_dir, pic_path, crop_info,
+                                                        enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+     print('The generated video is named:', result)
+     app.config['temp_response'] = base64_video
+
+     # shutil.move(result, save_dir+'.mp4')
+     if not args.verbose:
+         shutil.rmtree(save_dir)
+
+     return base64_video
+
+
+ def save_uploaded_file(file, filename):
+     unique_filename = str(uuid.uuid4()) + "_" + filename
+     file_path = os.path.join(TEMP_DIR.name, unique_filename)
+     file.save(file_path)
+     return file_path
+
+
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))  # read the key from the environment; never hard-code API keys
+
+
+ def translate_text(text, target_language):
+     response = client.chat.completions.create(
+         model="gpt-4-0125-preview",
+         messages=[
+             {"role": "system", "content": "You are a helpful assistant."},
+             {"role": "user", "content": f"Translate the following text into {target_language} completely: {text}\n"},
+         ],
+         max_tokens=len(text),
+         temperature=0.3,
+     )
+     return response
+
+
+ @app.route("/run", methods=['POST'])
+ def generate_video():
+     if request.method == 'POST':
+         source_image = request.files['source_image']
+         text_prompt = request.form['text_prompt']
+         voice_cloning = request.form.get('voice_cloning', 'no')
+         target_language = request.form.get('target_language', 'English')
+
+         if target_language != 'English':
+             response = translate_text(text_prompt, target_language)
+             text_prompt = response.choices[0].message.content.strip()
+             print('text_prompt', text_prompt)
+
+         source_image_path = save_uploaded_file(source_image, 'source_image.png')
+         print(source_image_path)
+
+         if voice_cloning == 'no':
+             response = client.audio.speech.create(model="tts-1-hd",
+                                                   voice="onyx",
+                                                   input=text_prompt)
+
+             with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", delete=False) as temp_file:
+                 driven_audio_path = temp_file.name
+
+             response.write_to_file(driven_audio_path)
+
+         elif voice_cloning == 'yes':
+             user_voice = request.files['user_voice']
+
+             with tempfile.NamedTemporaryFile(suffix=".wav", prefix="user_voice_", delete=False) as temp_file:
+                 user_voice_path = temp_file.name
+                 user_voice.save(user_voice_path)
+             print('user_voice_path', user_voice_path)
+
+             set_api_key(os.environ.get("ELEVENLABS_API_KEY"))  # read the key from the environment; never hard-code API keys
+             voice = clone(name="User Cloned Voice",  # name is not required; it is only used to store the voice on the ElevenLabs server
+                           files=[user_voice_path])
+
+             audio = generate(text=text_prompt, voice=voice, model="eleven_multilingual_v2")
+             with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_", delete=False) as temp_file:
+                 driven_audio_path = temp_file.name
+             elevenlabs.save(audio, driven_audio_path)
+
+         save_dir = tempfile.mkdtemp()
+         result_folder = os.path.join(save_dir, "results")
+         os.makedirs(result_folder, exist_ok=True)
+
+         # Build the animation config from the uploaded/generated assets
+         args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path, result_folder=result_folder)
+
+         if torch.cuda.is_available() and not args.cpu:
+             args.device = "cuda"
+         else:
+             args.device = "cpu"
+
+         generation_thread = threading.Thread(target=main, args=(args,))
+         app.config['generation_thread'] = generation_thread
+         generation_thread.start()
+         response_data = {"message": "Video generation started",
+                          "process_id": generation_thread.ident}
+
+         return jsonify(response_data)
+         # base64_video = main(args)
+         # return jsonify({"base64_video": base64_video})
+
+     # else:
+     #     return 'Unsupported HTTP method', 405
+
+
+ @app.route("/status", methods=["GET"])
+ def check_generation_status():
+     response = {"base64_video": "", "status": ""}
+     process_id = request.args.get('process_id', None)
+
+     # process_id is required to check the status of that specific generation thread
+     if process_id:
+         generation_thread = app.config.get('generation_thread')
+         if generation_thread and generation_thread.ident == int(process_id) and generation_thread.is_alive():
+             return jsonify({"status": "in_progress"}), 200
+         elif app.config.get('temp_response'):
+             # app.config['temp_response']['status'] = 'completed'
+             final_response = app.config['temp_response']
+             response["base64_video"] = final_response
+             response["status"] = "completed"
+             return jsonify(response)
+     return jsonify({"error": "No process id provided"})
+
+
+ @app.route("/health", methods=["GET"])
+ def health_status():
+     response = {"online": "true"}
+     return jsonify(response)
+
+
+ if __name__ == '__main__':
+     app.run(debug=True)
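
A minimal client sketch for the /run and /status endpoints defined in app.py above. It assumes the Flask app is running locally on Flask's default port 5000; the image path, prompt, and polling interval are placeholders.

# Hypothetical client for the SadTalker Flask API above; host, file names and prompt are placeholders.
import time
import requests

BASE_URL = "http://127.0.0.1:5000"

# Start a generation job: upload a source face image plus a text prompt.
with open("face.png", "rb") as image_file:
    resp = requests.post(
        f"{BASE_URL}/run",
        files={"source_image": image_file},
        data={"text_prompt": "Hello from SadTalker", "voice_cloning": "no", "target_language": "English"},
    )
process_id = resp.json()["process_id"]

# Poll /status until the background thread finishes and the base64-encoded video is available.
while True:
    status = requests.get(f"{BASE_URL}/status", params={"process_id": process_id}).json()
    if status.get("status") == "completed":
        print("received base64 video of length", len(status["base64_video"]))
        break
    time.sleep(5)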
app_sadtalker.py ADDED
@@ -0,0 +1,111 @@
+ import os, sys
+ import gradio as gr
+ from src.gradio_demo import SadTalker
+
+
+ try:
+     import webui  # running inside the stable-diffusion webui
+     in_webui = True
+ except:
+     in_webui = False
+
+
+ def toggle_audio_file(choice):
+     if choice == False:
+         return gr.update(visible=True), gr.update(visible=False)
+     else:
+         return gr.update(visible=False), gr.update(visible=True)
+
+ def ref_video_fn(path_of_ref_video):
+     if path_of_ref_video is not None:
+         return gr.update(value=True)
+     else:
+         return gr.update(value=False)
+
+ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):
+
+     sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)
+
+     with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
+         gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </h2> \
+                     <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
+                     <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
+                     <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </a></div>")
+
+         with gr.Row().style(equal_height=False):
+             with gr.Column(variant='panel'):
+                 with gr.Tabs(elem_id="sadtalker_source_image"):
+                     with gr.TabItem('Upload image'):
+                         with gr.Row():
+                             source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
+
+                 with gr.Tabs(elem_id="sadtalker_driven_audio"):
+                     with gr.TabItem('Upload OR TTS'):
+                         with gr.Column(variant='panel'):
+                             driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
+
+                         if sys.platform != 'win32' and not in_webui:
+                             from src.utils.text2speech import TTSTalker
+                             tts_talker = TTSTalker()
+                             with gr.Column(variant='panel'):
+                                 input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="Please enter some text here; we generate the audio from text using @Coqui.ai TTS.")
+                                 tts = gr.Button('Generate audio', elem_id="sadtalker_audio_generate", variant='primary')
+                                 tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
+
+             with gr.Column(variant='panel'):
+                 with gr.Tabs(elem_id="sadtalker_checkbox"):
+                     with gr.TabItem('Settings'):
+                         gr.Markdown("Need help? Please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more details.")
+                         with gr.Column(variant='panel'):
+                             # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
+                             # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
+                             pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0)
+                             size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?")
+                             preprocess_type = gr.Radio(['crop', 'resize', 'full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
+                             is_still_mode = gr.Checkbox(label="Still Mode (less head motion, works with preprocess `full`)")
+                             batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
+                             enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
+                             submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+
+                 with gr.Tabs(elem_id="sadtalker_genearted"):
+                     gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
+
+         if warpfn:
+             submit.click(
+                 fn=warpfn(sad_talker.test),
+                 inputs=[source_image,
+                         driven_audio,
+                         preprocess_type,
+                         is_still_mode,
+                         enhancer,
+                         batch_size,
+                         size_of_image,
+                         pose_style
+                         ],
+                 outputs=[gen_video]
+             )
+         else:
+             submit.click(
+                 fn=sad_talker.test,
+                 inputs=[source_image,
+                         driven_audio,
+                         preprocess_type,
+                         is_still_mode,
+                         enhancer,
+                         batch_size,
+                         size_of_image,
+                         pose_style
+                         ],
+                 outputs=[gen_video]
+             )
+
+     return sadtalker_interface
+
+
+ if __name__ == "__main__":
+
+     demo = sadtalker_demo()
+     demo.queue()
+     demo.launch()
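
A short launch sketch for the Gradio interface built by sadtalker_demo() above. It assumes the checkpoints and configs sit in the default locations; the server options shown are illustrative, not required.

# Hypothetical launch with explicit (assumed) paths and server options.
from app_sadtalker import sadtalker_demo

demo = sadtalker_demo(checkpoint_path="./checkpoints", config_path="./src/config")
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)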
cog.yaml ADDED
@@ -0,0 +1,35 @@
+ build:
+   gpu: true
+   cuda: "11.3"
+   python_version: "3.8"
+   system_packages:
+     - "ffmpeg"
+     - "libgl1-mesa-glx"
+     - "libglib2.0-0"
+   python_packages:
+     - "torch==1.12.1"
+     - "torchvision==0.13.1"
+     - "torchaudio==0.12.1"
+     - "joblib==1.1.0"
+     - "scikit-image==0.19.3"
+     - "basicsr==1.4.2"
+     - "facexlib==0.3.0"
+     - "resampy==0.3.1"
+     - "pydub==0.25.1"
+     - "scipy==1.10.1"
+     - "kornia==0.6.8"
+     - "face_alignment==1.3.5"
+     - "imageio==2.19.3"
+     - "imageio-ffmpeg==0.4.7"
+     - "librosa==0.9.2"
+     - "tqdm==4.65.0"
+     - "yacs==0.1.8"
+     - "gfpgan==1.3.8"
+     - "dlib-bin==19.24.1"
+     - "av==10.0.0"
+     - "trimesh==3.9.20"
+   run:
+     - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
+     - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"
+
+ predict: "predict.py:Predictor"
inference.py ADDED
@@ -0,0 +1,145 @@
+ from glob import glob
+ import shutil
+ import torch
+ from time import strftime
+ import os, sys, time
+ from argparse import ArgumentParser
+
+ from src.utils.preprocess import CropAndExtract
+ from src.test_audio2coeff import Audio2Coeff
+ from src.facerender.animate import AnimateFromCoeff
+ from src.generate_batch import get_data
+ from src.generate_facerender_batch import get_facerender_data
+ from src.utils.init_path import init_path
+
+ def main(args):
+     # torch.backends.cudnn.enabled = False
+
+     pic_path = args.source_image
+     audio_path = args.driven_audio
+     save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
+     os.makedirs(save_dir, exist_ok=True)
+     pose_style = args.pose_style
+     device = args.device
+     batch_size = args.batch_size
+     input_yaw_list = args.input_yaw
+     input_pitch_list = args.input_pitch
+     input_roll_list = args.input_roll
+     ref_eyeblink = args.ref_eyeblink
+     ref_pose = args.ref_pose
+
+     current_root_path = os.path.split(sys.argv[0])[0]
+
+     sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
+
+     # init models
+     preprocess_model = CropAndExtract(sadtalker_paths, device)
+
+     audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
+
+     animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+
+     # crop image and extract 3DMM coefficients from the image
+     first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+     os.makedirs(first_frame_dir, exist_ok=True)
+     print('3DMM Extraction for source image')
+     first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,
+                                                                            source_image_flag=True, pic_size=args.size)
+     if first_coeff_path is None:
+         print("Can't get the coeffs of the input")
+         return
+
+     if ref_eyeblink is not None:
+         ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
+         ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
+         os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
+         print('3DMM Extraction for the reference video providing eye blinking')
+         ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
+     else:
+         ref_eyeblink_coeff_path = None
+
+     if ref_pose is not None:
+         if ref_pose == ref_eyeblink:
+             ref_pose_coeff_path = ref_eyeblink_coeff_path
+         else:
+             ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
+             ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
+             os.makedirs(ref_pose_frame_dir, exist_ok=True)
+             print('3DMM Extraction for the reference video providing pose')
+             ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
+     else:
+         ref_pose_coeff_path = None
+
+     # audio2coeff
+     batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
+     coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+     # 3D face render
+     if args.face3dvis:
+         from src.face3d.visualize import gen_composed_video
+         gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
+
+     # coeff2video
+     data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
+                                batch_size, input_yaw_list, input_pitch_list, input_roll_list,
+                                expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
+
+     result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info,
+                                          enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+     shutil.move(result, save_dir + '.mp4')
+     print('The generated video is named:', save_dir + '.mp4')
+
+     if not args.verbose:
+         shutil.rmtree(save_dir)
+
+
+ if __name__ == '__main__':
+
+     parser = ArgumentParser()
+     parser.add_argument("--driven_audio", default='./examples/driven_audio/voice.wav', help="path to driven audio")
+     parser.add_argument("--source_image", default='./examples/source_image/istockphoto-487804668-612x612.png', help="path to source image")
+     parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
+     parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
+     parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the checkpoints")
+     parser.add_argument("--result_dir", default='./results', help="path to output")
+     parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
+     parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
+     parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
+     parser.add_argument("--expression_scale", type=float, default=1., help="the expression scale of the facerender")
+     parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user")
+     parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
+     parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
+     parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
+     parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
+     parser.add_argument("--cpu", dest="cpu", action="store_true")
+     parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
+     parser.add_argument("--still", action="store_true", help="crop back to the original videos for the full-body animation")
+     parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images")
+     parser.add_argument("--verbose", action="store_true", help="save the intermediate output or not")
+     parser.add_argument("--old_version", action="store_true", help="use the pth checkpoints rather than the safetensors version")
+
+     # net structure and parameters
+     parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='unused')
+     parser.add_argument('--init_path', type=str, default=None, help='unused')
+     parser.add_argument('--use_last_fc', default=False, help='zero initialize the last fc')
+     parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
+     parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+     # default renderer parameters
+     parser.add_argument('--focal', type=float, default=1015.)
+     parser.add_argument('--center', type=float, default=112.)
+     parser.add_argument('--camera_d', type=float, default=10.)
+     parser.add_argument('--z_near', type=float, default=5.)
+     parser.add_argument('--z_far', type=float, default=15.)
+
+     args = parser.parse_args()
+
+     if torch.cuda.is_available() and not args.cpu:
+         args.device = "cuda"
+     else:
+         args.device = "cpu"
+
+     main(args)
launcher.py ADDED
@@ -0,0 +1,204 @@
+ # this script installs the necessary requirements and launches the main program in webui.py
+ # borrowed from: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
+ import subprocess
+ import os
+ import sys
+ import importlib.util
+ import shlex
+ import platform
+ import json
+
+ python = sys.executable
+ git = os.environ.get('GIT', "git")
+ index_url = os.environ.get('INDEX_URL', "")
+ stored_commit_hash = None
+ skip_install = False
+ dir_repos = "repositories"
+ script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+ if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
+     os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
+
+ def check_python_version():
+     is_windows = platform.system() == "Windows"
+     major = sys.version_info.major
+     minor = sys.version_info.minor
+     micro = sys.version_info.micro
+
+     if is_windows:
+         supported_minors = [10]
+     else:
+         supported_minors = [7, 8, 9, 10, 11]
+
+     if not (major == 3 and minor in supported_minors):
+         # raising a bare string is invalid in Python 3, so wrap the message in RuntimeError
+         raise RuntimeError(f"""
+ INCOMPATIBLE PYTHON VERSION
+ This program is tested with Python 3.10.6, but you have {major}.{minor}.{micro}.
+ If you encounter an error with the message "RuntimeError: Couldn't install torch.",
+ or any other error regarding an unsuccessful package (library) installation,
+ please downgrade (or upgrade) to the latest version of Python 3.10
+ and delete the current Python and "venv" folder in WebUI's directory.
+ You can download Python 3.10 from here: https://www.python.org/downloads/release/python-3109/
+ {"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
+ Use --skip-python-version-check to suppress this warning.
+ """)
+
+
+ def commit_hash():
+     global stored_commit_hash
+
+     if stored_commit_hash is not None:
+         return stored_commit_hash
+
+     try:
+         stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
+     except Exception:
+         stored_commit_hash = "<none>"
+
+     return stored_commit_hash
+
+
+ def run(command, desc=None, errdesc=None, custom_env=None, live=False):
+     if desc is not None:
+         print(desc)
+
+     if live:
+         result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
+         if result.returncode != 0:
+             raise RuntimeError(f"""{errdesc or 'Error running command'}.
+ Command: {command}
+ Error code: {result.returncode}""")
+
+         return ""
+
+     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
+
+     if result.returncode != 0:
+         message = f"""{errdesc or 'Error running command'}.
+ Command: {command}
+ Error code: {result.returncode}
+ stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout) > 0 else '<empty>'}
+ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr) > 0 else '<empty>'}
+ """
+         raise RuntimeError(message)
+
+     return result.stdout.decode(encoding="utf8", errors="ignore")
+
+
+ def check_run(command):
+     result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
+     return result.returncode == 0
+
+
+ def is_installed(package):
+     try:
+         spec = importlib.util.find_spec(package)
+     except ModuleNotFoundError:
+         return False
+
+     return spec is not None
+
+
+ def repo_dir(name):
+     return os.path.join(script_path, dir_repos, name)
+
+
+ def run_python(code, desc=None, errdesc=None):
+     return run(f'"{python}" -c "{code}"', desc, errdesc)
+
+
+ def run_pip(args, desc=None):
+     if skip_install:
+         return
+
+     index_url_line = f' --index-url {index_url}' if index_url != '' else ''
+     return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
+
+
+ def check_run_python(code):
+     return check_run(f'"{python}" -c "{code}"')
+
+
+ def git_clone(url, dir, name, commithash=None):
+     # TODO clone into temporary dir and move if successful
+
+     if os.path.exists(dir):
+         if commithash is None:
+             return
+
+         current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
+         if current_hash == commithash:
+             return
+
+         run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
+         run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
+         return
+
+     run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
+
+     if commithash is not None:
+         run(f'"{git}" -C "{dir}" checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")
+
+
+ def git_pull_recursive(dir):
+     for subdir, _, _ in os.walk(dir):
+         if os.path.exists(os.path.join(subdir, '.git')):
+             try:
+                 output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
+                 print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
+             except subprocess.CalledProcessError as e:
+                 print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
+
+
+ def run_extension_installer(extension_dir):
+     path_installer = os.path.join(extension_dir, "install.py")
+     if not os.path.isfile(path_installer):
+         return
+
+     try:
+         env = os.environ.copy()
+         env['PYTHONPATH'] = os.path.abspath(".")
+
+         print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
+     except Exception as e:
+         print(e, file=sys.stderr)
+
+
+ def prepare_environment():
+     global skip_install
+
+     torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")
+
+     ## check windows
+     if sys.platform != 'win32':
+         requirements_file = os.environ.get('REQS_FILE', "req.txt")
+     else:
+         requirements_file = os.environ.get('REQS_FILE', "requirements.txt")
+
+     commit = commit_hash()
+
+     print(f"Python {sys.version}")
+     print(f"Commit hash: {commit}")
+
+     if not is_installed("torch") or not is_installed("torchvision"):
+         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
+
+     run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take a while on the first run)")
+
+     if sys.platform != 'win32' and not is_installed('tts'):
+         run_pip("install TTS", "TTS, installed separately for SadTalker (might not work on Windows)")
+
+
+ def start():
+     print("Launching SadTalker Web UI")
+     from app_sadtalker import sadtalker_demo
+     demo = sadtalker_demo()
+     demo.queue()
+     demo.launch()
+
+
+ if __name__ == "__main__":
+     prepare_environment()
+     start()
predict.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """run bash scripts/download_models.sh first to prepare the weights file"""
2
+ import os
3
+ import shutil
4
+ from argparse import Namespace
5
+ from src.utils.preprocess import CropAndExtract
6
+ from src.test_audio2coeff import Audio2Coeff
7
+ from src.facerender.animate import AnimateFromCoeff
8
+ from src.generate_batch import get_data
9
+ from src.generate_facerender_batch import get_facerender_data
10
+ from src.utils.init_path import init_path
11
+ from cog import BasePredictor, Input, Path
12
+
13
+ checkpoints = "checkpoints"
14
+
15
+
16
+ class Predictor(BasePredictor):
17
+ def setup(self):
18
+ """Load the model into memory to make running multiple predictions efficient"""
19
+ device = "cuda"
20
+
21
+
22
+ sadtalker_paths = init_path(checkpoints,os.path.join("src","config"))
23
+
24
+ # init model
25
+ self.preprocess_model = CropAndExtract(sadtalker_paths, device
26
+ )
27
+
28
+ self.audio_to_coeff = Audio2Coeff(
29
+ sadtalker_paths,
30
+ device,
31
+ )
32
+
33
+ self.animate_from_coeff = {
34
+ "full": AnimateFromCoeff(
35
+ sadtalker_paths,
36
+ device,
37
+ ),
38
+ "others": AnimateFromCoeff(
39
+ sadtalker_paths,
40
+ device,
41
+ ),
42
+ }
43
+
44
+ def predict(
45
+ self,
46
+ source_image: Path = Input(
47
+ description="Upload the source image, it can be video.mp4 or picture.png",
48
+ ),
49
+ driven_audio: Path = Input(
50
+ description="Upload the driven audio, accepts .wav and .mp4 file",
51
+ ),
52
+ enhancer: str = Input(
53
+ description="Choose a face enhancer",
54
+ choices=["gfpgan", "RestoreFormer"],
55
+ default="gfpgan",
56
+ ),
57
+ preprocess: str = Input(
58
+ description="how to preprocess the images",
59
+ choices=["crop", "resize", "full"],
60
+ default="full",
61
+ ),
62
+ ref_eyeblink: Path = Input(
63
+ description="path to reference video providing eye blinking",
64
+ default=None,
65
+ ),
66
+ ref_pose: Path = Input(
67
+ description="path to reference video providing pose",
68
+ default=None,
69
+ ),
70
+ still: bool = Input(
71
+ description="can crop back to the original videos for the full body aniamtion when preprocess is full",
72
+ default=True,
73
+ ),
74
+ ) -> Path:
75
+ """Run a single prediction on the model"""
76
+
77
+ animate_from_coeff = (
78
+ self.animate_from_coeff["full"]
79
+ if preprocess == "full"
80
+ else self.animate_from_coeff["others"]
81
+ )
82
+
83
+ args = load_default()
84
+ args.pic_path = str(source_image)
85
+ args.audio_path = str(driven_audio)
86
+ device = "cuda"
87
+ args.still = still
88
+ args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
89
+ args.ref_pose = None if ref_pose is None else str(ref_pose)
90
+
91
+ # crop image and extract 3dmm from image
92
+ results_dir = "results"
93
+ if os.path.exists(results_dir):
94
+ shutil.rmtree(results_dir)
95
+ os.makedirs(results_dir)
96
+ first_frame_dir = os.path.join(results_dir, "first_frame_dir")
97
+ os.makedirs(first_frame_dir)
98
+
99
+ print("3DMM Extraction for source image")
100
+ first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
101
+ args.pic_path, first_frame_dir, preprocess, source_image_flag=True
102
+ )
103
+ if first_coeff_path is None:
104
+ print("Can't get the coeffs of the input")
105
+ return
106
+
107
+ if ref_eyeblink is not None:
108
+ ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
109
+ 0
110
+ ]
111
+ ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
112
+ os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
113
+ print("3DMM Extraction for the reference video providing eye blinking")
114
+ ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
115
+ ref_eyeblink, ref_eyeblink_frame_dir
116
+ )
117
+ else:
118
+ ref_eyeblink_coeff_path = None
119
+
120
+ if ref_pose is not None:
121
+ if ref_pose == ref_eyeblink:
122
+ ref_pose_coeff_path = ref_eyeblink_coeff_path
123
+ else:
124
+ ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
125
+ ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
126
+ os.makedirs(ref_pose_frame_dir, exist_ok=True)
127
+ print("3DMM Extraction for the reference video providing pose")
128
+ ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
129
+ ref_pose, ref_pose_frame_dir
130
+ )
131
+ else:
132
+ ref_pose_coeff_path = None
133
+
134
+ # audio2coeff
135
+ batch = get_data(
136
+ first_coeff_path,
137
+ args.audio_path,
138
+ device,
139
+ ref_eyeblink_coeff_path,
140
+ still=still,
141
+ )
142
+ coeff_path = self.audio_to_coeff.generate(
143
+ batch, results_dir, args.pose_style, ref_pose_coeff_path
144
+ )
145
+ # coeff2video
146
+ print("coeff2video")
147
+ data = get_facerender_data(
148
+ coeff_path,
149
+ crop_pic_path,
150
+ first_coeff_path,
151
+ args.audio_path,
152
+ args.batch_size,
153
+ args.input_yaw,
154
+ args.input_pitch,
155
+ args.input_roll,
156
+ expression_scale=args.expression_scale,
157
+ still_mode=still,
158
+ preprocess=preprocess,
159
+ )
160
+ animate_from_coeff.generate(
161
+ data, results_dir, args.pic_path, crop_info,
162
+ enhancer=enhancer, background_enhancer=args.background_enhancer,
163
+ preprocess=preprocess)
164
+
165
+ output = "/tmp/out.mp4"
166
+ mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0])
167
+ shutil.copy(mp4_path, output)
168
+
169
+ return Path(output)
170
+
171
+
172
+ def load_default():
173
+ return Namespace(
174
+ pose_style=0,
175
+ batch_size=2,
176
+ expression_scale=1.0,
177
+ input_yaw=None,
178
+ input_pitch=None,
179
+ input_roll=None,
180
+ background_enhancer=None,
181
+ face3dvis=False,
182
+ net_recon="resnet50",
183
+ init_path=None,
184
+ use_last_fc=False,
185
+ bfm_folder="./src/config/",
186
+ bfm_model="BFM_model_front.mat",
187
+ focal=1015.0,
188
+ center=112.0,
189
+ camera_d=10.0,
190
+ z_near=5.0,
191
+ z_far=15.0,
192
+ )
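For reference, Cog calls setup() once and then predict() per request. The snippet below is a minimal local smoke test of the Predictor defined above; it assumes the cog package is installed, the checkpoints have been fetched with scripts/download_models.sh, a CUDA GPU is available, and the example assets bundled with the SadTalker repo are present.

# Hypothetical local test of the Cog Predictor above; normally `cog predict` drives this.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads CropAndExtract, Audio2Coeff and AnimateFromCoeff onto the GPU
result = predictor.predict(
    source_image="examples/source_image/full3.png",          # example asset from the repo
    driven_audio="examples/driven_audio/RD_Radio31_000.wav",  # example asset from the repo
    enhancer="gfpgan",
    preprocess="full",
    ref_eyeblink=None,
    ref_pose=None,
    still=True,
)
print(result)  # Path("/tmp/out.mp4")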
quick_demo.ipynb ADDED
@@ -0,0 +1,221 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {
7
+ "id": "M74Gs_TjYl_B"
8
+ },
9
+ "source": [
10
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
11
+ ]
12
+ },
13
+ {
14
+ "attachments": {},
15
+ "cell_type": "markdown",
16
+ "metadata": {
17
+ "id": "view-in-github"
18
+ },
19
+ "source": [
20
+ "### SadTalker:Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n",
21
+ "\n",
22
+ "[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n",
23
+ "\n",
24
+ "Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
25
+ "\n",
26
+ "Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
27
+ "\n",
28
+ "CVPR 2023\n",
29
+ "\n",
30
+ "TL;DR: A realistic and stylized talking head video generation method from a single image and audio\n"
31
+ ]
32
+ },
33
+ {
34
+ "attachments": {},
35
+ "cell_type": "markdown",
36
+ "metadata": {
37
+ "id": "kA89DV-sKS4i"
38
+ },
39
+ "source": [
40
+ "Installation (around 5 mins)"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {
47
+ "id": "qJ4CplXsYl_E"
48
+ },
49
+ "outputs": [],
50
+ "source": [
51
+ "### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n",
52
+ "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {
59
+ "id": "Mdq6j4E5KQAR"
60
+ },
61
+ "outputs": [],
62
+ "source": [
63
+ "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n",
64
+ "!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n",
65
+ "!sudo apt install python3.8\n",
66
+ "\n",
67
+ "!sudo apt-get install python3.8-distutils\n",
68
+ "\n",
69
+ "!python --version\n",
70
+ "\n",
71
+ "!apt-get update\n",
72
+ "\n",
73
+ "!apt install software-properties-common\n",
74
+ "\n",
75
+ "!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
76
+ "\n",
77
+ "!apt-get install python3-pip\n",
78
+ "\n",
79
+ "print('Git clone project and install requirements...')\n",
80
+ "!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
81
+ "%cd SadTalker\n",
82
+ "!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\n",
83
+ "!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
84
+ "!apt update\n",
85
+ "!apt install ffmpeg &> /dev/null\n",
86
+ "!python3.8 -m pip install -r requirements.txt"
87
+ ]
88
+ },
89
+ {
90
+ "attachments": {},
91
+ "cell_type": "markdown",
92
+ "metadata": {
93
+ "id": "DddcKB_nKsnk"
94
+ },
95
+ "source": [
96
+ "Download models (1 mins)"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "id": "eDw3_UN8K2xa"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "print('Download pre-trained models...')\n",
108
+ "!rm -rf checkpoints\n",
109
+ "!bash scripts/download_models.sh"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "metadata": {
116
+ "id": "kK7DYeo7Yl_H"
117
+ },
118
+ "outputs": [],
119
+ "source": [
120
+ "# borrow from makeittalk\n",
121
+ "import ipywidgets as widgets\n",
122
+ "import glob\n",
123
+ "import matplotlib.pyplot as plt\n",
124
+ "print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
125
+ "img_list = glob.glob1('examples/source_image', '*.png')\n",
126
+ "img_list.sort()\n",
127
+ "img_list = [item.split('.')[0] for item in img_list]\n",
128
+ "default_head_name = widgets.Dropdown(options=img_list, value='full3')\n",
129
+ "def on_change(change):\n",
130
+ " if change['type'] == 'change' and change['name'] == 'value':\n",
131
+ " plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
132
+ " plt.axis('off')\n",
133
+ " plt.show()\n",
134
+ "default_head_name.observe(on_change)\n",
135
+ "display(default_head_name)\n",
136
+ "plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
137
+ "plt.axis('off')\n",
138
+ "plt.show()"
139
+ ]
140
+ },
141
+ {
142
+ "attachments": {},
143
+ "cell_type": "markdown",
144
+ "metadata": {
145
+ "id": "-khNZcnGK4UK"
146
+ },
147
+ "source": [
148
+ "Animation"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "metadata": {
155
+ "id": "ToBlDusjK5sS"
156
+ },
157
+ "outputs": [],
158
+ "source": [
159
+ "# selected audio from exmaple/driven_audio\n",
160
+ "img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
161
+ "print(img)\n",
162
+ "!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
163
+ " --source_image {img} \\\n",
164
+ " --result_dir ./results --still --preprocess full --enhancer gfpgan"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "metadata": {
171
+ "id": "fAjwGmKKYl_I"
172
+ },
173
+ "outputs": [],
174
+ "source": [
175
+ "# visualize code from makeittalk\n",
176
+ "from IPython.display import HTML\n",
177
+ "from base64 import b64encode\n",
178
+ "import os, sys\n",
179
+ "\n",
180
+ "# get the last from results\n",
181
+ "\n",
182
+ "results = sorted(os.listdir('./results/'))\n",
183
+ "\n",
184
+ "mp4_name = glob.glob('./results/*.mp4')[0]\n",
185
+ "\n",
186
+ "mp4 = open('{}'.format(mp4_name),'rb').read()\n",
187
+ "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
188
+ "\n",
189
+ "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
190
+ "display(HTML(\"\"\"\n",
191
+ " <video width=256 controls>\n",
192
+ " <source src=\"%s\" type=\"video/mp4\">\n",
193
+ " </video>\n",
194
+ " \"\"\" % data_url))\n"
195
+ ]
196
+ }
197
+ ],
198
+ "metadata": {
199
+ "accelerator": "GPU",
200
+ "colab": {
201
+ "provenance": []
202
+ },
203
+ "gpuClass": "standard",
204
+ "kernelspec": {
205
+ "display_name": "base",
206
+ "language": "python",
207
+ "name": "python3"
208
+ },
209
+ "language_info": {
210
+ "name": "python",
211
+ "version": "3.9.7"
212
+ },
213
+ "vscode": {
214
+ "interpreter": {
215
+ "hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
216
+ }
217
+ }
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 0
221
+ }
req.txt ADDED
@@ -0,0 +1,22 @@
1
+ llvmlite==0.38.1
2
+ numpy==1.21.6
3
+ face_alignment==1.3.5
4
+ imageio==2.19.3
5
+ imageio-ffmpeg==0.4.7
6
+ librosa==0.10.0.post2
7
+ numba==0.55.1
8
+ resampy==0.3.1
9
+ pydub==0.25.1
10
+ scipy==1.10.1
11
+ kornia==0.6.8
12
+ tqdm
13
+ yacs==0.1.8
14
+ pyyaml
15
+ joblib==1.1.0
16
+ scikit-image==0.19.3
17
+ basicsr==1.4.2
18
+ facexlib==0.3.0
19
+ gradio
20
+ gfpgan
21
+ av
22
+ safetensors
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ numpy==1.23.4
2
+ face_alignment==1.3.5
3
+ imageio==2.19.3
4
+ imageio-ffmpeg==0.4.7
5
+ librosa==0.9.2
6
+ numba
7
+ resampy==0.3.1
8
+ pydub==0.25.1
9
+ scipy==1.10.1
10
+ kornia==0.6.8
11
+ tqdm
12
+ yacs==0.1.8
13
+ pyyaml
14
+ joblib==1.1.0
15
+ scikit-image==0.19.3
16
+ basicsr==1.4.2
17
+ facexlib==0.3.0
18
+ torchvision==0.12.0
19
+ elevenlabs
20
+ gradio
21
+ gfpgan
22
+ av
23
+ safetensors
24
+ openai
25
+ torch
26
+ moviepy
27
+ flask
28
+ gunicorn
29
+ flask_cors
30
+ flask_swagger_ui
requirements3d.txt ADDED
@@ -0,0 +1,21 @@
1
+ numpy==1.23.4
2
+ face_alignment==1.3.5
3
+ imageio==2.19.3
4
+ imageio-ffmpeg==0.4.7
5
+ librosa==0.9.2
6
+ numba
7
+ resampy==0.3.1
8
+ pydub==0.25.1
9
+ scipy==1.5.3
10
+ kornia==0.6.8
11
+ tqdm
12
+ yacs==0.1.8
13
+ pyyaml
14
+ joblib==1.1.0
15
+ scikit-image==0.19.3
16
+ basicsr==1.4.2
17
+ facexlib==0.3.0
18
+ trimesh==3.9.20
19
+ gradio
20
+ gfpgan
21
+ safetensors
scripts/download_models.sh ADDED
@@ -0,0 +1,15 @@
1
+ #!/bin/bash
2
+ # Create the checkpoints directory
3
+ mkdir -p ./checkpoints
4
+ # Download model files into the checkpoints directory
5
+ wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar" -O "./checkpoints/mapping_00109-model.pth.tar"
6
+ wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar" -O "./checkpoints/mapping_00229-model.pth.tar"
7
+ wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors" -O "./checkpoints/SadTalker_V0.0.2_256.safetensors"
8
+ wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors" -O "./checkpoints/SadTalker_V0.0.2_512.safetensors"
9
+ # Create the weights directory
10
+ mkdir -p -v "./gfpgan/weights"
11
+ # Download enhancer model files into the weights directory
12
+ wget -nc "https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth" -O "./gfpgan/weights/alignment_WFLW_4HG.pth"
13
+ wget -nc "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" -O "./gfpgan/weights/detection_Resnet50_Final.pth"
14
+ wget -nc "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" -O "./gfpgan/weights/GFPGANv1.4.pth"
15
+ wget -nc "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth" -O "./gfpgan/weights/parsing_parsenet.pth"
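After running `bash scripts/download_models.sh`, a quick sanity check along these lines confirms the files the script is expected to fetch are in place before launching inference; the list is taken directly from the wget calls above (a sketch, not part of the repo).

# Sketch: verify the checkpoints fetched by download_models.sh are present.
import os

expected = [
    "checkpoints/mapping_00109-model.pth.tar",
    "checkpoints/mapping_00229-model.pth.tar",
    "checkpoints/SadTalker_V0.0.2_256.safetensors",
    "checkpoints/SadTalker_V0.0.2_512.safetensors",
    "gfpgan/weights/GFPGANv1.4.pth",
]
missing = [p for p in expected if not os.path.isfile(p)]
print("all model files present" if not missing else "missing: %s" % missing)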
src/audio2exp_models/audio2exp.py ADDED
@@ -0,0 +1,41 @@
1
+ from tqdm import tqdm
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class Audio2Exp(nn.Module):
7
+ def __init__(self, netG, cfg, device, prepare_training_loss=False):
8
+ super(Audio2Exp, self).__init__()
9
+ self.cfg = cfg
10
+ self.device = device
11
+ self.netG = netG.to(device)
12
+
13
+ def test(self, batch):
14
+
15
+ mel_input = batch['indiv_mels'] # bs T 1 80 16
16
+ bs = mel_input.shape[0]
17
+ T = mel_input.shape[1]
18
+
19
+ exp_coeff_pred = []
20
+
21
+ for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
22
+
23
+ current_mel_input = mel_input[:,i:i+10]
24
+
25
+ #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
26
+ ref = batch['ref'][:, :, :64][:, i:i+10]
27
+ ratio = batch['ratio_gt'][:, i:i+10] #bs T
28
+
29
+ audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
30
+
31
+ curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
32
+
33
+ exp_coeff_pred += [curr_exp_coeff_pred]
34
+
35
+ # BS x T x 64
36
+ results_dict = {
37
+ 'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
38
+ }
39
+ return results_dict
40
+
41
+
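The tensor layout expected by Audio2Exp.test is only documented through the inline shape comments; the sketch below feeds randomly initialised tensors with those shapes through the module, using SimpleWrapperV2 from the companion networks.py as the generator (an assumption about how it is wired up elsewhere in the repo; run from the repo root).

# Shape sketch for Audio2Exp.test with random weights (illustration only).
import torch
from src.audio2exp_models.networks import SimpleWrapperV2
from src.audio2exp_models.audio2exp import Audio2Exp

model = Audio2Exp(SimpleWrapperV2(), cfg=None, device="cpu")
batch = {
    "indiv_mels": torch.randn(1, 20, 1, 80, 16),  # bs, T, 1, 80, 16 mel chunks
    "ref": torch.randn(1, 20, 70),                # first 64 channels are expression coeffs
    "ratio_gt": torch.rand(1, 20, 1),             # per-frame blink ratio
}
out = model.test(batch)
print(out["exp_coeff_pred"].shape)                # torch.Size([1, 20, 64])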
src/audio2exp_models/networks.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+ self.use_act = use_act
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+
21
+ if self.use_act:
22
+ return self.act(out)
23
+ else:
24
+ return out
25
+
26
+ class SimpleWrapperV2(nn.Module):
27
+ def __init__(self) -> None:
28
+ super().__init__()
29
+ self.audio_encoder = nn.Sequential(
30
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
31
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
35
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
39
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
40
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
41
+
42
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
43
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
44
+
45
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
46
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
47
+ )
48
+
49
+ #### load the pre-trained audio_encoder
50
+ #self.audio_encoder = self.audio_encoder.to(device)
51
+ '''
52
+ wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
53
+ state_dict = self.audio_encoder.state_dict()
54
+
55
+ for k,v in wav2lip_state_dict.items():
56
+ if 'audio_encoder' in k:
57
+ print('init:', k)
58
+ state_dict[k.replace('module.audio_encoder.', '')] = v
59
+ self.audio_encoder.load_state_dict(state_dict)
60
+ '''
61
+
62
+ self.mapping1 = nn.Linear(512+64+1, 64)
63
+ #self.mapping2 = nn.Linear(30, 64)
64
+ #nn.init.constant_(self.mapping1.weight, 0.)
65
+ nn.init.constant_(self.mapping1.bias, 0.)
66
+
67
+ def forward(self, x, ref, ratio):
68
+ x = self.audio_encoder(x).view(x.size(0), -1)
69
+ ref_reshape = ref.reshape(x.size(0), -1)
70
+ ratio = ratio.reshape(x.size(0), -1)
71
+
72
+ y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
73
+ out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
74
+ return out
src/audio2pose_models/audio2pose.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from torch import nn
3
+ from src.audio2pose_models.cvae import CVAE
4
+ from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
5
+ from src.audio2pose_models.audio_encoder import AudioEncoder
6
+
7
+ class Audio2Pose(nn.Module):
8
+ def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
9
+ super().__init__()
10
+ self.cfg = cfg
11
+ self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
12
+ self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
13
+ self.device = device
14
+
15
+ self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
16
+ self.audio_encoder.eval()
17
+ for param in self.audio_encoder.parameters():
18
+ param.requires_grad = False
19
+
20
+ self.netG = CVAE(cfg)
21
+ self.netD_motion = PoseSequenceDiscriminator(cfg)
22
+
23
+
24
+ def forward(self, x):
25
+
26
+ batch = {}
27
+ coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
28
+ batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
29
+ batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
30
+ batch['class'] = x['class'].squeeze(0).cuda() # bs
31
+ indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
32
+
33
+ # forward
34
+ audio_emb_list = []
35
+ audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
36
+ batch['audio_emb'] = audio_emb
37
+ batch = self.netG(batch)
38
+
39
+ pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
40
+ pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
41
+ pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
42
+
43
+ batch['pose_pred'] = pose_pred
44
+ batch['pose_gt'] = pose_gt
45
+
46
+ return batch
47
+
48
+ def test(self, x):
49
+
50
+ batch = {}
51
+ ref = x['ref'] #bs 1 70
52
+ batch['ref'] = x['ref'][:,0,-6:]
53
+ batch['class'] = x['class']
54
+ bs = ref.shape[0]
55
+
56
+ indiv_mels= x['indiv_mels'] # bs T 1 80 16
57
+ indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
58
+ num_frames = x['num_frames']
59
+ num_frames = int(num_frames) - 1
60
+
61
+ #
62
+ div = num_frames//self.seq_len
63
+ re = num_frames%self.seq_len
64
+ audio_emb_list = []
65
+ pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
66
+ device=batch['ref'].device)]
67
+
68
+ for i in range(div):
69
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
70
+ batch['z'] = z
71
+ audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
72
+ batch['audio_emb'] = audio_emb
73
+ batch = self.netG.test(batch)
74
+ pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
75
+
76
+ if re != 0:
77
+ z = torch.randn(bs, self.latent_dim).to(ref.device)
78
+ batch['z'] = z
79
+ audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
80
+ if audio_emb.shape[1] != self.seq_len:
81
+ pad_dim = self.seq_len-audio_emb.shape[1]
82
+ pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
83
+ audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
84
+ batch['audio_emb'] = audio_emb
85
+ batch = self.netG.test(batch)
86
+ pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
87
+
88
+ pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
89
+ batch['pose_motion_pred'] = pose_motion_pred
90
+
91
+ pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
92
+
93
+ batch['pose_pred'] = pose_pred
94
+ return batch
src/audio2pose_models/audio_encoder.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Conv2d(nn.Module):
6
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
7
+ super().__init__(*args, **kwargs)
8
+ self.conv_block = nn.Sequential(
9
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
10
+ nn.BatchNorm2d(cout)
11
+ )
12
+ self.act = nn.ReLU()
13
+ self.residual = residual
14
+
15
+ def forward(self, x):
16
+ out = self.conv_block(x)
17
+ if self.residual:
18
+ out += x
19
+ return self.act(out)
20
+
21
+ class AudioEncoder(nn.Module):
22
+ def __init__(self, wav2lip_checkpoint, device):
23
+ super(AudioEncoder, self).__init__()
24
+
25
+ self.audio_encoder = nn.Sequential(
26
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
27
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
28
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
29
+
30
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
31
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
32
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
33
+
34
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
35
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
36
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
37
+
38
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
39
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
40
+
41
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
42
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
43
+
44
+ #### load the pre-trained audio_encoder, we do not need to load wav2lip model here.
45
+ # wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
46
+ # state_dict = self.audio_encoder.state_dict()
47
+
48
+ # for k,v in wav2lip_state_dict.items():
49
+ # if 'audio_encoder' in k:
50
+ # state_dict[k.replace('module.audio_encoder.', '')] = v
51
+ # self.audio_encoder.load_state_dict(state_dict)
52
+
53
+
54
+ def forward(self, audio_sequences):
55
+ # audio_sequences = (B, T, 1, 80, 16)
56
+ B = audio_sequences.size(0)
57
+
58
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
59
+
60
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
61
+ dim = audio_embedding.shape[1]
62
+ audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
63
+
64
+ return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
src/audio2pose_models/cvae.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from src.audio2pose_models.res_unet import ResUnet
5
+
6
+ def class2onehot(idx, class_num):
7
+
8
+ assert torch.max(idx).item() < class_num
9
+ onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
10
+ onehot.scatter_(1, idx, 1)
11
+ return onehot
12
+
13
+ class CVAE(nn.Module):
14
+ def __init__(self, cfg):
15
+ super().__init__()
16
+ encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
17
+ decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
18
+ latent_size = cfg.MODEL.CVAE.LATENT_SIZE
19
+ num_classes = cfg.DATASET.NUM_CLASSES
20
+ audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
21
+ audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
22
+ seq_len = cfg.MODEL.CVAE.SEQ_LEN
23
+
24
+ self.latent_size = latent_size
25
+
26
+ self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
27
+ audio_emb_in_size, audio_emb_out_size, seq_len)
28
+ self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
29
+ audio_emb_in_size, audio_emb_out_size, seq_len)
30
+ def reparameterize(self, mu, logvar):
31
+ std = torch.exp(0.5 * logvar)
32
+ eps = torch.randn_like(std)
33
+ return mu + eps * std
34
+
35
+ def forward(self, batch):
36
+ batch = self.encoder(batch)
37
+ mu = batch['mu']
38
+ logvar = batch['logvar']
39
+ z = self.reparameterize(mu, logvar)
40
+ batch['z'] = z
41
+ return self.decoder(batch)
42
+
43
+ def test(self, batch):
44
+ '''
45
+ class_id = batch['class']
46
+ z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
47
+ batch['z'] = z
48
+ '''
49
+ return self.decoder(batch)
50
+
51
+ class ENCODER(nn.Module):
52
+ def __init__(self, layer_sizes, latent_size, num_classes,
53
+ audio_emb_in_size, audio_emb_out_size, seq_len):
54
+ super().__init__()
55
+
56
+ self.resunet = ResUnet()
57
+ self.num_classes = num_classes
58
+ self.seq_len = seq_len
59
+
60
+ self.MLP = nn.Sequential()
61
+ layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
62
+ for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
63
+ self.MLP.add_module(
64
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
65
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
66
+
67
+ self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
68
+ self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
69
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
70
+
71
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
72
+
73
+ def forward(self, batch):
74
+ class_id = batch['class']
75
+ pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
76
+ ref = batch['ref'] #bs 6
77
+ bs = pose_motion_gt.shape[0]
78
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
79
+
80
+ #pose encode
81
+ pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
82
+ pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
83
+
84
+ #audio mapping
85
+ print(audio_in.shape)
86
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
87
+ audio_out = audio_out.reshape(bs, -1)
88
+
89
+ class_bias = self.classbias[class_id] #bs latent_size
90
+ x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
91
+ x_out = self.MLP(x_in)
92
+
93
+ mu = self.linear_means(x_out)
94
+ logvar = self.linear_means(x_out) #bs latent_size
95
+
96
+ batch.update({'mu':mu, 'logvar':logvar})
97
+ return batch
98
+
99
+ class DECODER(nn.Module):
100
+ def __init__(self, layer_sizes, latent_size, num_classes,
101
+ audio_emb_in_size, audio_emb_out_size, seq_len):
102
+ super().__init__()
103
+
104
+ self.resunet = ResUnet()
105
+ self.num_classes = num_classes
106
+ self.seq_len = seq_len
107
+
108
+ self.MLP = nn.Sequential()
109
+ input_size = latent_size + seq_len*audio_emb_out_size + 6
110
+ for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
111
+ self.MLP.add_module(
112
+ name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
113
+ if i+1 < len(layer_sizes):
114
+ self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
115
+ else:
116
+ self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
117
+
118
+ self.pose_linear = nn.Linear(6, 6)
119
+ self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
120
+
121
+ self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
122
+
123
+ def forward(self, batch):
124
+
125
+ z = batch['z'] #bs latent_size
126
+ bs = z.shape[0]
127
+ class_id = batch['class']
128
+ ref = batch['ref'] #bs 6
129
+ audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
130
+ #print('audio_in: ', audio_in[:, :, :10])
131
+
132
+ audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
133
+ #print('audio_out: ', audio_out[:, :, :10])
134
+ audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
135
+ class_bias = self.classbias[class_id] #bs latent_size
136
+
137
+ z = z + class_bias
138
+ x_in = torch.cat([ref, z, audio_out], dim=-1)
139
+ x_out = self.MLP(x_in) # bs layer_sizes[-1]
140
+ x_out = x_out.reshape((bs, self.seq_len, -1))
141
+
142
+ #print('x_out: ', x_out)
143
+
144
+ pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
145
+
146
+ pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
147
+
148
+ batch.update({'pose_motion_pred':pose_motion_pred})
149
+ return batch
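The only stochastic piece of the CVAE is the reparameterisation step; the standalone fragment below reproduces that formula to show why sampling stays differentiable with respect to mu and logvar (batch and latent sizes here are arbitrary, chosen only for illustration).

# Reparameterisation trick in isolation: z = mu + eps * std, with eps ~ N(0, I).
import torch

mu = torch.zeros(4, 64, requires_grad=True)
logvar = torch.zeros(4, 64, requires_grad=True)
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std        # same expression as CVAE.reparameterize
z.sum().backward()        # gradients reach mu and logvar through the sample
print(mu.grad.shape, logvar.grad.shape)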
src/audio2pose_models/discriminator.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ class ConvNormRelu(nn.Module):
6
+ def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
7
+ kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
8
+ super().__init__()
9
+ if kernel_size is None:
10
+ if downsample:
11
+ kernel_size, stride, padding = 4, 2, 1
12
+ else:
13
+ kernel_size, stride, padding = 3, 1, 1
14
+
15
+ if conv_type == '2d':
16
+ self.conv = nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size,
20
+ stride,
21
+ padding,
22
+ bias=False,
23
+ )
24
+ if norm == 'BN':
25
+ self.norm = nn.BatchNorm2d(out_channels)
26
+ elif norm == 'IN':
27
+ self.norm = nn.InstanceNorm2d(out_channels)
28
+ else:
29
+ raise NotImplementedError
30
+ elif conv_type == '1d':
31
+ self.conv = nn.Conv1d(
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ stride,
36
+ padding,
37
+ bias=False,
38
+ )
39
+ if norm == 'BN':
40
+ self.norm = nn.BatchNorm1d(out_channels)
41
+ elif norm == 'IN':
42
+ self.norm = nn.InstanceNorm1d(out_channels)
43
+ else:
44
+ raise NotImplementedError
45
+ nn.init.kaiming_normal_(self.conv.weight)
46
+
47
+ self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ if isinstance(self.norm, nn.InstanceNorm1d):
52
+ x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
53
+ else:
54
+ x = self.norm(x)
55
+ x = self.act(x)
56
+ return x
57
+
58
+
59
+ class PoseSequenceDiscriminator(nn.Module):
60
+ def __init__(self, cfg):
61
+ super().__init__()
62
+ self.cfg = cfg
63
+ leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
64
+
65
+ self.seq = nn.Sequential(
66
+ ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
67
+ ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
68
+ ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
69
+ nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
70
+ )
71
+
72
+ def forward(self, x):
73
+ x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
74
+ x = self.seq(x)
75
+ x = x.squeeze(1)
76
+ return x
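PoseSequenceDiscriminator only reads the LEAKY_RELU and INPUT_CHANNELS entries of the config; the sketch below stubs those out with a SimpleNamespace using the values from src/config/auido2pose.yaml to show the expected input and output shapes (an illustration with random weights, not how the training code builds its cfg).

# Shape sketch for the pose-sequence discriminator (random weights, stubbed cfg).
import torch
from types import SimpleNamespace
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator

cfg = SimpleNamespace(MODEL=SimpleNamespace(
    DISCRIMINATOR=SimpleNamespace(LEAKY_RELU=False, INPUT_CHANNELS=6)))
disc = PoseSequenceDiscriminator(cfg)
pose_seq = torch.randn(4, 32, 6)   # bs, seq_len, 6 pose coefficients
print(disc(pose_seq).shape)        # torch.Size([4, 8]) - one score per downsampled step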
src/audio2pose_models/networks.py ADDED
@@ -0,0 +1,140 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+
5
+ class ResidualConv(nn.Module):
6
+ def __init__(self, input_dim, output_dim, stride, padding):
7
+ super(ResidualConv, self).__init__()
8
+
9
+ self.conv_block = nn.Sequential(
10
+ nn.BatchNorm2d(input_dim),
11
+ nn.ReLU(),
12
+ nn.Conv2d(
13
+ input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
14
+ ),
15
+ nn.BatchNorm2d(output_dim),
16
+ nn.ReLU(),
17
+ nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
18
+ )
19
+ self.conv_skip = nn.Sequential(
20
+ nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
21
+ nn.BatchNorm2d(output_dim),
22
+ )
23
+
24
+ def forward(self, x):
25
+
26
+ return self.conv_block(x) + self.conv_skip(x)
27
+
28
+
29
+ class Upsample(nn.Module):
30
+ def __init__(self, input_dim, output_dim, kernel, stride):
31
+ super(Upsample, self).__init__()
32
+
33
+ self.upsample = nn.ConvTranspose2d(
34
+ input_dim, output_dim, kernel_size=kernel, stride=stride
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.upsample(x)
39
+
40
+
41
+ class Squeeze_Excite_Block(nn.Module):
42
+ def __init__(self, channel, reduction=16):
43
+ super(Squeeze_Excite_Block, self).__init__()
44
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
45
+ self.fc = nn.Sequential(
46
+ nn.Linear(channel, channel // reduction, bias=False),
47
+ nn.ReLU(inplace=True),
48
+ nn.Linear(channel // reduction, channel, bias=False),
49
+ nn.Sigmoid(),
50
+ )
51
+
52
+ def forward(self, x):
53
+ b, c, _, _ = x.size()
54
+ y = self.avg_pool(x).view(b, c)
55
+ y = self.fc(y).view(b, c, 1, 1)
56
+ return x * y.expand_as(x)
57
+
58
+
59
+ class ASPP(nn.Module):
60
+ def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
61
+ super(ASPP, self).__init__()
62
+
63
+ self.aspp_block1 = nn.Sequential(
64
+ nn.Conv2d(
65
+ in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
66
+ ),
67
+ nn.ReLU(inplace=True),
68
+ nn.BatchNorm2d(out_dims),
69
+ )
70
+ self.aspp_block2 = nn.Sequential(
71
+ nn.Conv2d(
72
+ in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
73
+ ),
74
+ nn.ReLU(inplace=True),
75
+ nn.BatchNorm2d(out_dims),
76
+ )
77
+ self.aspp_block3 = nn.Sequential(
78
+ nn.Conv2d(
79
+ in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
80
+ ),
81
+ nn.ReLU(inplace=True),
82
+ nn.BatchNorm2d(out_dims),
83
+ )
84
+
85
+ self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
86
+ self._init_weights()
87
+
88
+ def forward(self, x):
89
+ x1 = self.aspp_block1(x)
90
+ x2 = self.aspp_block2(x)
91
+ x3 = self.aspp_block3(x)
92
+ out = torch.cat([x1, x2, x3], dim=1)
93
+ return self.output(out)
94
+
95
+ def _init_weights(self):
96
+ for m in self.modules():
97
+ if isinstance(m, nn.Conv2d):
98
+ nn.init.kaiming_normal_(m.weight)
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class Upsample_(nn.Module):
105
+ def __init__(self, scale=2):
106
+ super(Upsample_, self).__init__()
107
+
108
+ self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
109
+
110
+ def forward(self, x):
111
+ return self.upsample(x)
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+ def __init__(self, input_encoder, input_decoder, output_dim):
116
+ super(AttentionBlock, self).__init__()
117
+
118
+ self.conv_encoder = nn.Sequential(
119
+ nn.BatchNorm2d(input_encoder),
120
+ nn.ReLU(),
121
+ nn.Conv2d(input_encoder, output_dim, 3, padding=1),
122
+ nn.MaxPool2d(2, 2),
123
+ )
124
+
125
+ self.conv_decoder = nn.Sequential(
126
+ nn.BatchNorm2d(input_decoder),
127
+ nn.ReLU(),
128
+ nn.Conv2d(input_decoder, output_dim, 3, padding=1),
129
+ )
130
+
131
+ self.conv_attn = nn.Sequential(
132
+ nn.BatchNorm2d(output_dim),
133
+ nn.ReLU(),
134
+ nn.Conv2d(output_dim, 1, 1),
135
+ )
136
+
137
+ def forward(self, x1, x2):
138
+ out = self.conv_encoder(x1) + self.conv_decoder(x2)
139
+ out = self.conv_attn(out)
140
+ return out * x2
src/audio2pose_models/res_unet.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.audio2pose_models.networks import ResidualConv, Upsample
4
+
5
+
6
+ class ResUnet(nn.Module):
7
+ def __init__(self, channel=1, filters=[32, 64, 128, 256]):
8
+ super(ResUnet, self).__init__()
9
+
10
+ self.input_layer = nn.Sequential(
11
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
12
+ nn.BatchNorm2d(filters[0]),
13
+ nn.ReLU(),
14
+ nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
15
+ )
16
+ self.input_skip = nn.Sequential(
17
+ nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
18
+ )
19
+
20
+ self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
21
+ self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
22
+
23
+ self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
24
+
25
+ self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
26
+ self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
27
+
28
+ self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
29
+ self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
30
+
31
+ self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
32
+ self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
33
+
34
+ self.output_layer = nn.Sequential(
35
+ nn.Conv2d(filters[0], 1, 1, 1),
36
+ nn.Sigmoid(),
37
+ )
38
+
39
+ def forward(self, x):
40
+ # Encode
41
+ x1 = self.input_layer(x) + self.input_skip(x)
42
+ x2 = self.residual_conv_1(x1)
43
+ x3 = self.residual_conv_2(x2)
44
+ # Bridge
45
+ x4 = self.bridge(x3)
46
+
47
+ # Decode
48
+ x4 = self.upsample_1(x4)
49
+ x5 = torch.cat([x4, x3], dim=1)
50
+
51
+ x6 = self.up_residual_conv1(x5)
52
+
53
+ x6 = self.upsample_2(x6)
54
+ x7 = torch.cat([x6, x2], dim=1)
55
+
56
+ x8 = self.up_residual_conv2(x7)
57
+
58
+ x8 = self.upsample_3(x8)
59
+ x9 = torch.cat([x8, x1], dim=1)
60
+
61
+ x10 = self.up_residual_conv3(x9)
62
+
63
+ output = self.output_layer(x10)
64
+
65
+ return output
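ResUnet is applied to pose-motion sequences rather than images, so the strided blocks use stride (2, 1) and the (seq_len, 6) layout is preserved end to end; a quick shape check with random weights (run from the repo root):

# Shape sketch: ResUnet keeps the (seq_len, 6) pose layout, downsampling only in time.
import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet(channel=1)
x = torch.randn(2, 1, 32, 6)   # bs, 1, seq_len, 6 - as used inside the CVAE
print(net(x).shape)            # torch.Size([2, 1, 32, 6])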
src/config/auido2exp.yaml ADDED
@@ -0,0 +1,58 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
4
+ TRAIN_BATCH_SIZE: 32
5
+ EVAL_BATCH_SIZE: 32
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
13
+ LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
14
+ DEBUG: True
15
+ NUM_REPEATS: 2
16
+ T: 40
17
+
18
+
19
+ MODEL:
20
+ FRAMEWORK: V2
21
+ AUDIOENCODER:
22
+ LEAKY_RELU: True
23
+ NORM: 'IN'
24
+ DISCRIMINATOR:
25
+ LEAKY_RELU: False
26
+ INPUT_CHANNELS: 6
27
+ CVAE:
28
+ AUDIO_EMB_IN_SIZE: 512
29
+ AUDIO_EMB_OUT_SIZE: 128
30
+ SEQ_LEN: 32
31
+ LATENT_SIZE: 256
32
+ ENCODER_LAYER_SIZES: [192, 1024]
33
+ DECODER_LAYER_SIZES: [1024, 192]
34
+
35
+
36
+ TRAIN:
37
+ MAX_EPOCH: 300
38
+ GENERATOR:
39
+ LR: 2.0e-5
40
+ DISCRIMINATOR:
41
+ LR: 1.0e-5
42
+ LOSS:
43
+ W_FEAT: 0
44
+ W_COEFF_EXP: 2
45
+ W_LM: 1.0e-2
46
+ W_LM_MOUTH: 0
47
+ W_REG: 0
48
+ W_SYNC: 0
49
+ W_COLOR: 0
50
+ W_EXPRESSION: 0
51
+ W_LIPREADING: 0.01
52
+ W_LIPREADING_VV: 0
53
+ W_EYE_BLINK: 4
54
+
55
+ TAG:
56
+ NAME: small_dataset
57
+
58
+
src/config/auido2pose.yaml ADDED
@@ -0,0 +1,49 @@
1
+ DATASET:
2
+ TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
3
+ EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
4
+ TRAIN_BATCH_SIZE: 64
5
+ EVAL_BATCH_SIZE: 1
6
+ EXP: True
7
+ EXP_DIM: 64
8
+ FRAME_LEN: 32
9
+ COEFF_LEN: 73
10
+ NUM_CLASSES: 46
11
+ AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
12
+ COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
13
+ DEBUG: True
14
+
15
+
16
+ MODEL:
17
+ AUDIOENCODER:
18
+ LEAKY_RELU: True
19
+ NORM: 'IN'
20
+ DISCRIMINATOR:
21
+ LEAKY_RELU: False
22
+ INPUT_CHANNELS: 6
23
+ CVAE:
24
+ AUDIO_EMB_IN_SIZE: 512
25
+ AUDIO_EMB_OUT_SIZE: 6
26
+ SEQ_LEN: 32
27
+ LATENT_SIZE: 64
28
+ ENCODER_LAYER_SIZES: [192, 128]
29
+ DECODER_LAYER_SIZES: [128, 192]
30
+
31
+
32
+ TRAIN:
33
+ MAX_EPOCH: 150
34
+ GENERATOR:
35
+ LR: 1.0e-4
36
+ DISCRIMINATOR:
37
+ LR: 1.0e-4
38
+ LOSS:
39
+ LAMBDA_REG: 1
40
+ LAMBDA_LANDMARKS: 0
41
+ LAMBDA_VERTICES: 0
42
+ LAMBDA_GAN_MOTION: 0.7
43
+ LAMBDA_GAN_COEFF: 0
44
+ LAMBDA_KL: 1
45
+
46
+ TAG:
47
+ NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
48
+
49
+
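These YAML files are consumed through yacs (pinned in the requirements); loading one into the CfgNode object that the audio2pose/audio2exp models expect looks roughly like this, a sketch mirroring how the rest of the codebase reads its configs.

# Sketch: read a config into a yacs CfgNode and access the CVAE settings.
from yacs.config import CfgNode as CN

with open("src/config/auido2pose.yaml") as f:
    cfg = CN.load_cfg(f)
print(cfg.MODEL.CVAE.SEQ_LEN, cfg.MODEL.CVAE.LATENT_SIZE)  # 32 64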
src/config/facerender.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 70
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
src/config/facerender_still.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_params:
2
+ common_params:
3
+ num_kp: 15
4
+ image_channel: 3
5
+ feature_channel: 32
6
+ estimate_jacobian: False # True
7
+ kp_detector_params:
8
+ temperature: 0.1
9
+ block_expansion: 32
10
+ max_features: 1024
11
+ scale_factor: 0.25 # 0.25
12
+ num_blocks: 5
13
+ reshape_channel: 16384 # 16384 = 1024 * 16
14
+ reshape_depth: 16
15
+ he_estimator_params:
16
+ block_expansion: 64
17
+ max_features: 2048
18
+ num_bins: 66
19
+ generator_params:
20
+ block_expansion: 64
21
+ max_features: 512
22
+ num_down_blocks: 2
23
+ reshape_channel: 32
24
+ reshape_depth: 16 # 512 = 32 * 16
25
+ num_resblocks: 6
26
+ estimate_occlusion_map: True
27
+ dense_motion_params:
28
+ block_expansion: 32
29
+ max_features: 1024
30
+ num_blocks: 5
31
+ reshape_depth: 16
32
+ compress: 4
33
+ discriminator_params:
34
+ scales: [1]
35
+ block_expansion: 32
36
+ max_features: 512
37
+ num_blocks: 4
38
+ sn: True
39
+ mapping_params:
40
+ coeff_nc: 73
41
+ descriptor_nc: 1024
42
+ layer: 3
43
+ num_kp: 15
44
+ num_bins: 66
45
+
src/config/similarity_Lm3D_all.mat ADDED
Binary file (994 Bytes).
src/face3d/data/__init__.py ADDED
@@ -0,0 +1,116 @@
1
+ """This package includes all the modules related to data loading and preprocessing
2
+
3
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4
+ You need to implement four functions:
5
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6
+ -- <__len__>: return the size of dataset.
7
+ -- <__getitem__>: get a data point from data loader.
8
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9
+
10
+ Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11
+ See our template dataset class 'template_dataset.py' for more details.
12
+ """
13
+ import numpy as np
14
+ import importlib
15
+ import torch.utils.data
16
+ from face3d.data.base_dataset import BaseDataset
17
+
18
+
19
+ def find_dataset_using_name(dataset_name):
20
+ """Import the module "data/[dataset_name]_dataset.py".
21
+
22
+ In the file, the class called DatasetNameDataset() will
23
+ be instantiated. It has to be a subclass of BaseDataset,
24
+ and it is case-insensitive.
25
+ """
26
+ dataset_filename = "data." + dataset_name + "_dataset"
27
+ datasetlib = importlib.import_module(dataset_filename)
28
+
29
+ dataset = None
30
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31
+ for name, cls in datasetlib.__dict__.items():
32
+ if name.lower() == target_dataset_name.lower() \
33
+ and issubclass(cls, BaseDataset):
34
+ dataset = cls
35
+
36
+ if dataset is None:
37
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38
+
39
+ return dataset
40
+
41
+
42
+ def get_option_setter(dataset_name):
43
+ """Return the static method <modify_commandline_options> of the dataset class."""
44
+ dataset_class = find_dataset_using_name(dataset_name)
45
+ return dataset_class.modify_commandline_options
46
+
47
+
48
+ def create_dataset(opt, rank=0):
49
+ """Create a dataset given the option.
50
+
51
+ This function wraps the class CustomDatasetDataLoader.
52
+ This is the main interface between this package and 'train.py'/'test.py'
53
+
54
+ Example:
55
+ >>> from data import create_dataset
56
+ >>> dataset = create_dataset(opt)
57
+ """
58
+ data_loader = CustomDatasetDataLoader(opt, rank=rank)
59
+ dataset = data_loader.load_data()
60
+ return dataset
61
+
62
+ class CustomDatasetDataLoader():
63
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
64
+
65
+ def __init__(self, opt, rank=0):
66
+ """Initialize this class
67
+
68
+ Step 1: create a dataset instance given the name [dataset_mode]
69
+ Step 2: create a multi-threaded data loader.
70
+ """
71
+ self.opt = opt
72
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
73
+ self.dataset = dataset_class(opt)
74
+ self.sampler = None
75
+ print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76
+ if opt.use_ddp and opt.isTrain:
77
+ world_size = opt.world_size
78
+ self.sampler = torch.utils.data.distributed.DistributedSampler(
79
+ self.dataset,
80
+ num_replicas=world_size,
81
+ rank=rank,
82
+ shuffle=not opt.serial_batches
83
+ )
84
+ self.dataloader = torch.utils.data.DataLoader(
85
+ self.dataset,
86
+ sampler=self.sampler,
87
+ num_workers=int(opt.num_threads / world_size),
88
+ batch_size=int(opt.batch_size / world_size),
89
+ drop_last=True)
90
+ else:
91
+ self.dataloader = torch.utils.data.DataLoader(
92
+ self.dataset,
93
+ batch_size=opt.batch_size,
94
+ shuffle=(not opt.serial_batches) and opt.isTrain,
95
+ num_workers=int(opt.num_threads),
96
+ drop_last=True
97
+ )
98
+
99
+ def set_epoch(self, epoch):
100
+ self.dataset.current_epoch = epoch
101
+ if self.sampler is not None:
102
+ self.sampler.set_epoch(epoch)
103
+
104
+ def load_data(self):
105
+ return self
106
+
107
+ def __len__(self):
108
+ """Return the number of data in the dataset"""
109
+ return min(len(self.dataset), self.opt.max_dataset_size)
110
+
111
+ def __iter__(self):
112
+ """Return a batch of data"""
113
+ for i, data in enumerate(self.dataloader):
114
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
115
+ break
116
+ yield data
src/face3d/data/base_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2
+
3
+ It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4
+ """
5
+ import random
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+ from PIL import Image
9
+ import torchvision.transforms as transforms
10
+ from abc import ABC, abstractmethod
11
+
12
+
13
+ class BaseDataset(data.Dataset, ABC):
14
+ """This class is an abstract base class (ABC) for datasets.
15
+
16
+ To create a subclass, you need to implement the following four functions:
17
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18
+ -- <__len__>: return the size of dataset.
19
+ -- <__getitem__>: get a data point.
20
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the class; save the options in the class
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ self.opt = opt
30
+ # self.root = opt.dataroot
31
+ self.current_epoch = 0
32
+
33
+ @staticmethod
34
+ def modify_commandline_options(parser, is_train):
35
+ """Add new dataset-specific options, and rewrite default values for existing options.
36
+
37
+ Parameters:
38
+ parser -- original option parser
39
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40
+
41
+ Returns:
42
+ the modified parser.
43
+ """
44
+ return parser
45
+
46
+ @abstractmethod
47
+ def __len__(self):
48
+ """Return the total number of images in the dataset."""
49
+ return 0
50
+
51
+ @abstractmethod
52
+ def __getitem__(self, index):
53
+ """Return a data point and its metadata information.
54
+
55
+ Parameters:
56
+ index - - a random integer for data indexing
57
+
58
+ Returns:
59
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
60
+ """
61
+ pass
62
+
63
+
64
+ def get_transform(grayscale=False):
65
+ transform_list = []
66
+ if grayscale:
67
+ transform_list.append(transforms.Grayscale(1))
68
+ transform_list += [transforms.ToTensor()]
69
+ return transforms.Compose(transform_list)
70
+
71
+ def get_affine_mat(opt, size):
72
+ shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73
+ w, h = size
74
+
75
+ if 'shift' in opt.preprocess:
76
+ shift_pixs = int(opt.shift_pixs)
77
+ shift_x = random.randint(-shift_pixs, shift_pixs)
78
+ shift_y = random.randint(-shift_pixs, shift_pixs)
79
+ if 'scale' in opt.preprocess:
80
+ scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81
+ if 'rot' in opt.preprocess:
82
+ rot_angle = opt.rot_angle * (2 * random.random() - 1)
83
+ rot_rad = -rot_angle * np.pi/180
84
+ if 'flip' in opt.preprocess:
85
+ flip = random.random() > 0.5
86
+
87
+ shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88
+ flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89
+ shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90
+ rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91
+ scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92
+ shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93
+
94
+ affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95
+ affine_inv = np.linalg.inv(affine)
96
+ return affine, affine_inv, flip
97
+
98
+ def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
99
+ return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
100
+
101
+ def apply_lm_affine(landmark, affine, flip, size):
102
+ _, h = size
103
+ lm = landmark.copy()
104
+ lm[:, 1] = h - 1 - lm[:, 1]
105
+ lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106
+ lm = lm @ np.transpose(affine)
107
+ lm[:, :2] = lm[:, :2] / lm[:, 2:]
108
+ lm = lm[:, :2]
109
+ lm[:, 1] = h - 1 - lm[:, 1]
110
+ if flip:
111
+ lm_ = lm.copy()
112
+ lm_[:17] = lm[16::-1]
113
+ lm_[17:22] = lm[26:21:-1]
114
+ lm_[22:27] = lm[21:16:-1]
115
+ lm_[31:36] = lm[35:30:-1]
116
+ lm_[36:40] = lm[45:41:-1]
117
+ lm_[40:42] = lm[47:45:-1]
118
+ lm_[42:46] = lm[39:35:-1]
119
+ lm_[46:48] = lm[41:39:-1]
120
+ lm_[48:55] = lm[54:47:-1]
121
+ lm_[55:60] = lm[59:54:-1]
122
+ lm_[60:65] = lm[64:59:-1]
123
+ lm_[65:68] = lm[67:64:-1]
124
+ lm = lm_
125
+ return lm
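get_affine_mat reads a handful of attributes off opt (preprocess, shift_pixs, scale_delta, rot_angle); the sketch below stubs a minimal opt with a Namespace to show how the three helpers combine for image and landmark augmentation. The values are placeholders, not the project's training settings, and the sys.path tweak matches the face3d package's own import convention.

# Illustration of the augmentation helpers with a stubbed opt (values are arbitrary).
import sys
sys.path.insert(0, "src")   # so face3d.* resolves, as the package's own imports expect
from argparse import Namespace
import numpy as np
from PIL import Image
from face3d.data.base_dataset import get_affine_mat, apply_img_affine, apply_lm_affine

opt = Namespace(preprocess="shift_scale_rot_flip", shift_pixs=10, scale_delta=0.1, rot_angle=10)
img = Image.new("RGB", (224, 224))
lm = np.zeros((68, 2), dtype=np.float32)          # 68 facial landmarks

affine, affine_inv, flip = get_affine_mat(opt, img.size)
img_aug = apply_img_affine(img, affine_inv)
lm_aug = apply_lm_affine(lm, affine, flip, img.size)
print(img_aug.size, lm_aug.shape, flip)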
src/face3d/data/flist_dataset.py ADDED
@@ -0,0 +1,125 @@
1
+ """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2
+ """
3
+
4
+ import os.path
5
+ from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6
+ from data.image_folder import make_dataset
7
+ from PIL import Image
8
+ import random
9
+ import util.util as util
10
+ import numpy as np
11
+ import json
12
+ import torch
13
+ from scipy.io import loadmat, savemat
14
+ import pickle
15
+ from util.preprocess import align_img, estimate_norm
16
+ from util.load_mats import load_lm3d
17
+
18
+
19
+ def default_flist_reader(flist):
20
+ """
21
+ flist format: impath label\nimpath label\n ...(same to caffe's filelist)
22
+ """
23
+ imlist = []
24
+ with open(flist, 'r') as rf:
25
+ for line in rf.readlines():
26
+ impath = line.strip()
27
+ imlist.append(impath)
28
+
29
+ return imlist
30
+
31
+ def jason_flist_reader(flist):
32
+ with open(flist, 'r') as fp:
33
+ info = json.load(fp)
34
+ return info
35
+
36
+ def parse_label(label):
37
+ return torch.tensor(np.array(label).astype(np.float32))
38
+
39
+
40
+ class FlistDataset(BaseDataset):
41
+ """
42
+ It requires one directories to host training images '/path/to/data/train'
43
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
44
+ """
45
+
46
+ def __init__(self, opt):
47
+ """Initialize this dataset class.
48
+
49
+ Parameters:
50
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51
+ """
52
+ BaseDataset.__init__(self, opt)
53
+
54
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
55
+
56
+ msk_names = default_flist_reader(opt.flist)
57
+ self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58
+
59
+ self.size = len(self.msk_paths)
60
+ self.opt = opt
61
+
62
+ self.name = 'train' if opt.isTrain else 'val'
63
+ if '_' in opt.flist:
64
+ self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65
+
66
+
67
+ def __getitem__(self, index):
68
+ """Return a data point and its metadata information.
69
+
70
+ Parameters:
71
+ index (int) -- a random integer for data indexing
72
+
73
+ Returns a dictionary that contains A, B, A_paths and B_paths
74
+ img (tensor) -- an image in the input domain
75
+ msk (tensor) -- its corresponding attention mask
76
+ lm (tensor) -- its corresponding 3d landmarks
77
+ im_paths (str) -- image paths
78
+ aug_flag (bool) -- a flag used to tell whether its raw or augmented
79
+ """
80
+ msk_path = self.msk_paths[index % self.size] # make sure the index is within range
81
+ img_path = msk_path.replace('mask/', '')
82
+ lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83
+
84
+ raw_img = Image.open(img_path).convert('RGB')
85
+ raw_msk = Image.open(msk_path).convert('RGB')
86
+ raw_lm = np.loadtxt(lm_path).astype(np.float32)
87
+
88
+ _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89
+
90
+ aug_flag = self.opt.use_aug and self.opt.isTrain
91
+ if aug_flag:
92
+ img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93
+
94
+ _, H = img.size
95
+ M = estimate_norm(lm, H)
96
+ transform = get_transform()
97
+ img_tensor = transform(img)
98
+ msk_tensor = transform(msk)[:1, ...]
99
+ lm_tensor = parse_label(lm)
100
+ M_tensor = parse_label(M)
101
+
102
+
103
+ return {'imgs': img_tensor,
104
+ 'lms': lm_tensor,
105
+ 'msks': msk_tensor,
106
+ 'M': M_tensor,
107
+ 'im_paths': img_path,
108
+ 'aug_flag': aug_flag,
109
+ 'dataset': self.name}
110
+
111
+ def _augmentation(self, img, lm, opt, msk=None):
112
+ affine, affine_inv, flip = get_affine_mat(opt, img.size)
113
+ img = apply_img_affine(img, affine_inv)
114
+ lm = apply_lm_affine(lm, affine, flip, img.size)
115
+ if msk is not None:
116
+ msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117
+ return img, lm, msk
118
+
119
+
120
+
121
+
122
+ def __len__(self):
123
+ """Return the total number of images in the dataset.
124
+ """
125
+ return self.size
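
A quick illustration of the path convention the loader above appears to assume (the paths here are hypothetical): each flist entry points at a mask image, and the RGB image and landmark file are derived from it by the same string substitutions used in `__getitem__`.

```python
# Hypothetical flist entry "person1/mask/000001.png" under data_root "/path/to/data"
msk_path = "/path/to/data/person1/mask/000001.png"                                # attention mask
img_path = msk_path.replace('mask/', '')                                          # -> /path/to/data/person1/000001.png
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'  # -> /path/to/data/person1/landmarks/000001.txt
```
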
src/face3d/data/image_folder.py ADDED
@@ -0,0 +1,66 @@
1
+ """A modified image folder class
2
+
3
+ We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4
+ so that this class can load images from both current directory and its subdirectories.
5
+ """
6
+ import numpy as np
7
+ import torch.utils.data as data
8
+
9
+ from PIL import Image
10
+ import os
11
+ import os.path
12
+
13
+ IMG_EXTENSIONS = [
14
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
15
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16
+ '.tif', '.TIF', '.tiff', '.TIFF',
17
+ ]
18
+
19
+
20
+ def is_image_file(filename):
21
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22
+
23
+
24
+ def make_dataset(dir, max_dataset_size=float("inf")):
25
+ images = []
26
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27
+
28
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29
+ for fname in fnames:
30
+ if is_image_file(fname):
31
+ path = os.path.join(root, fname)
32
+ images.append(path)
33
+ return images[:min(max_dataset_size, len(images))]
34
+
35
+
36
+ def default_loader(path):
37
+ return Image.open(path).convert('RGB')
38
+
39
+
40
+ class ImageFolder(data.Dataset):
41
+
42
+ def __init__(self, root, transform=None, return_paths=False,
43
+ loader=default_loader):
44
+ imgs = make_dataset(root)
45
+ if len(imgs) == 0:
46
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
47
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48
+
49
+ self.root = root
50
+ self.imgs = imgs
51
+ self.transform = transform
52
+ self.return_paths = return_paths
53
+ self.loader = loader
54
+
55
+ def __getitem__(self, index):
56
+ path = self.imgs[index]
57
+ img = self.loader(path)
58
+ if self.transform is not None:
59
+ img = self.transform(img)
60
+ if self.return_paths:
61
+ return img, path
62
+ else:
63
+ return img
64
+
65
+ def __len__(self):
66
+ return len(self.imgs)
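
A minimal usage sketch of the folder loader above; the import path and image directory are assumptions and may need adjusting to how the repo is installed.

```python
import torchvision.transforms as T
from torch.utils.data import DataLoader
from src.face3d.data.image_folder import ImageFolder  # adjust to your package layout

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
dataset = ImageFolder('/path/to/images', transform=transform, return_paths=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

imgs, paths = next(iter(loader))  # imgs: (4, 3, 224, 224) tensor, paths: tuple of file paths
```
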
src/face3d/data/template_dataset.py ADDED
@@ -0,0 +1,75 @@
1
+ """Dataset class template
2
+
3
+ This module provides a template for users to implement custom datasets.
4
+ You can specify '--dataset_mode template' to use this dataset.
5
+ The class name should be consistent with both the filename and its dataset_mode option.
6
+ The filename should be <dataset_mode>_dataset.py
7
+ The class name should be <Dataset_mode>Dataset
8
+ You need to implement the following functions:
9
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10
+ -- <__init__>: Initialize this dataset class.
11
+ -- <__getitem__>: Return a data point and its metadata information.
12
+ -- <__len__>: Return the number of images.
13
+ """
14
+ from data.base_dataset import BaseDataset, get_transform
15
+ # from data.image_folder import make_dataset
16
+ # from PIL import Image
17
+
18
+
19
+ class TemplateDataset(BaseDataset):
20
+ """A template dataset class for you to implement custom datasets."""
21
+ @staticmethod
22
+ def modify_commandline_options(parser, is_train):
23
+ """Add new dataset-specific options, and rewrite default values for existing options.
24
+
25
+ Parameters:
26
+ parser -- original option parser
27
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28
+
29
+ Returns:
30
+ the modified parser.
31
+ """
32
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34
+ return parser
35
+
36
+ def __init__(self, opt):
37
+ """Initialize this dataset class.
38
+
39
+ Parameters:
40
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41
+
42
+ A few things can be done here.
43
+ - save the options (have been done in BaseDataset)
44
+ - get image paths and meta information of the dataset.
45
+ - define the image transformation.
46
+ """
47
+ # save the option and dataset root
48
+ BaseDataset.__init__(self, opt)
49
+ # get the image paths of your dataset;
50
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
52
+ self.transform = get_transform(opt)
53
+
54
+ def __getitem__(self, index):
55
+ """Return a data point and its metadata information.
56
+
57
+ Parameters:
58
+ index -- a random integer for data indexing
59
+
60
+ Returns:
61
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
62
+
63
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
64
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66
+ Step 4: return a data point as a dictionary.
67
+ """
68
+ path = 'temp' # needs to be a string
69
+ data_A = None # needs to be a tensor
70
+ data_B = None # needs to be a tensor
71
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
72
+
73
+ def __len__(self):
74
+ """Return the total number of images."""
75
+ return len(self.image_paths)
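
As a concrete (hypothetical) instance of this template, a "single"-image dataset would live in data/single_dataset.py with class name SingleDataset and could look roughly like this, assuming base_dataset.get_transform and image_folder.make_dataset behave as used elsewhere in this package:

```python
from PIL import Image
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class SingleDataset(BaseDataset):
    """Loads single images from opt.dataroot; enabled with '--dataset_mode single'."""

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.image_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]
        image = Image.open(path).convert('RGB')
        return {'data_A': self.transform(image), 'path': path}

    def __len__(self):
        return len(self.image_paths)
```
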
src/face3d/extract_kp_videos.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import face_alignment
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
+
14
+ class KeypointExtractor():
15
+ def __init__(self, device):
16
+ self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
17
+ device=device)
18
+
19
+ def extract_keypoint(self, images, name=None, info=True):
20
+ if isinstance(images, list):
21
+ keypoints = []
22
+ if info:
23
+ i_range = tqdm(images,desc='landmark Det:')
24
+ else:
25
+ i_range = images
26
+
27
+ for image in i_range:
28
+ current_kp = self.extract_keypoint(image)
29
+ if np.mean(current_kp) == -1 and keypoints:
30
+ keypoints.append(keypoints[-1])
31
+ else:
32
+ keypoints.append(current_kp[None])
33
+
34
+ keypoints = np.concatenate(keypoints, 0)
35
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
36
+ return keypoints
37
+ else:
38
+ while True:
39
+ try:
40
+ keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
41
+ break
42
+ except RuntimeError as e:
43
+ if str(e).startswith('CUDA'):
44
+ print("Warning: out of memory, sleep for 1s")
45
+ time.sleep(1)
46
+ else:
47
+ print(e)
48
+ break
49
+ except TypeError:
50
+ print('No face detected in this image')
51
+ shape = [68, 2]
52
+ keypoints = -1. * np.ones(shape)
53
+ break
54
+ if name is not None:
55
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
56
+ return keypoints
57
+
58
+ def read_video(filename):
59
+ frames = []
60
+ cap = cv2.VideoCapture(filename)
61
+ while cap.isOpened():
62
+ ret, frame = cap.read()
63
+ if ret:
64
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
+ frame = Image.fromarray(frame)
66
+ frames.append(frame)
67
+ else:
68
+ break
69
+ cap.release()
70
+ return frames
71
+
72
+ def run(data):
73
+ filename, opt, device = data
74
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
75
+ kp_extractor = KeypointExtractor(device='cuda')  # __init__ requires a device argument
76
+ images = read_video(filename)
77
+ name = filename.split('/')[-2:]
78
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
79
+ kp_extractor.extract_keypoint(
80
+ images,
81
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
82
+ )
83
+
84
+ if __name__ == '__main__':
85
+ set_start_method('spawn')
86
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
87
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
88
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
89
+ parser.add_argument('--device_ids', type=str, default='0,1')
90
+ parser.add_argument('--workers', type=int, default=4)
91
+
92
+ opt = parser.parse_args()
93
+ filenames = list()
94
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
95
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
96
+ extensions = VIDEO_EXTENSIONS
97
+
98
+ for ext in extensions:
99
+ os.listdir(f'{opt.input_dir}')
100
+ print(f'{opt.input_dir}/*.{ext}')
101
+ filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))  # accumulate matches across extensions
102
+ print('Total number of videos:', len(filenames))
103
+ pool = Pool(opt.workers)
104
+ args_list = cycle([opt])
105
+ device_ids = opt.device_ids.split(",")
106
+ device_ids = cycle(device_ids)
107
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
108
+ pass
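
The extractor flattens the keypoint array before writing it with np.savetxt, so a saved file has to be reshaped on load. A small sketch (the output path is hypothetical):

```python
import numpy as np

kp_flat = np.loadtxt('output_dir/video_name/clip.txt')  # written by extract_keypoint(..., name=...)
keypoints = kp_flat.reshape(-1, 68, 2)                   # one (68, 2) landmark set per frame
print(keypoints.shape)
```
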
src/face3d/extract_kp_videos_safe.py ADDED
@@ -0,0 +1,151 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import glob
5
+ import argparse
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from tqdm import tqdm
10
+ from itertools import cycle
11
+ from torch.multiprocessing import Pool, Process, set_start_method
12
+
13
+ from facexlib.alignment import landmark_98_to_68
14
+ from facexlib.detection import init_detection_model
15
+
16
+ from facexlib.utils import load_file_from_url
17
+ from src.face3d.util.my_awing_arch import FAN
18
+
19
+ def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
20
+ if model_name == 'awing_fan':
21
+ model = FAN(num_modules=4, num_landmarks=98, device=device)
22
+ model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
23
+ else:
24
+ raise NotImplementedError(f'{model_name} is not implemented.')
25
+
26
+ model_path = load_file_from_url(
27
+ url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
28
+ model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
29
+ model.eval()
30
+ model = model.to(device)
31
+ return model
32
+
33
+
34
+ class KeypointExtractor():
35
+ def __init__(self, device='cuda'):
36
+
37
+ ### gfpgan/weights
38
+ try:
39
+ import webui # in webui
40
+ root_path = 'extensions/SadTalker/gfpgan/weights'
41
+
42
+ except:
43
+ root_path = 'gfpgan/weights'
44
+
45
+ self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)
46
+ self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
47
+
48
+ def extract_keypoint(self, images, name=None, info=True):
49
+ if isinstance(images, list):
50
+ keypoints = []
51
+ if info:
52
+ i_range = tqdm(images,desc='landmark Det:')
53
+ else:
54
+ i_range = images
55
+
56
+ for image in i_range:
57
+ current_kp = self.extract_keypoint(image)
58
+ # current_kp = self.detector.get_landmarks(np.array(image))
59
+ if np.mean(current_kp) == -1 and keypoints:
60
+ keypoints.append(keypoints[-1])
61
+ else:
62
+ keypoints.append(current_kp[None])
63
+
64
+ keypoints = np.concatenate(keypoints, 0)
65
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
66
+ return keypoints
67
+ else:
68
+ while True:
69
+ try:
70
+ with torch.no_grad():
71
+ # face detection -> face alignment.
72
+ img = np.array(images)
73
+ bboxes = self.det_net.detect_faces(images, 0.97)
74
+
75
+ bboxes = bboxes[0]
76
+ img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
77
+
78
+ keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
79
+
80
+ #### keypoints to the original location
81
+ keypoints[:,0] += int(bboxes[0])
82
+ keypoints[:,1] += int(bboxes[1])
83
+
84
+ break
85
+ except RuntimeError as e:
86
+ if str(e).startswith('CUDA'):
87
+ print("Warning: out of memory, sleep for 1s")
88
+ time.sleep(1)
89
+ else:
90
+ print(e)
91
+ break
92
+ except TypeError:
93
+ print('No face detected in this image')
94
+ shape = [68, 2]
95
+ keypoints = -1. * np.ones(shape)
96
+ break
97
+ if name is not None:
98
+ np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
99
+ return keypoints
100
+
101
+ def read_video(filename):
102
+ frames = []
103
+ cap = cv2.VideoCapture(filename)
104
+ while cap.isOpened():
105
+ ret, frame = cap.read()
106
+ if ret:
107
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+ frame = Image.fromarray(frame)
109
+ frames.append(frame)
110
+ else:
111
+ break
112
+ cap.release()
113
+ return frames
114
+
115
+ def run(data):
116
+ filename, opt, device = data
117
+ os.environ['CUDA_VISIBLE_DEVICES'] = device
118
+ kp_extractor = KeypointExtractor()
119
+ images = read_video(filename)
120
+ name = filename.split('/')[-2:]
121
+ os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
122
+ kp_extractor.extract_keypoint(
123
+ images,
124
+ name=os.path.join(opt.output_dir, name[-2], name[-1])
125
+ )
126
+
127
+ if __name__ == '__main__':
128
+ set_start_method('spawn')
129
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
130
+ parser.add_argument('--input_dir', type=str, help='the folder of the input files')
131
+ parser.add_argument('--output_dir', type=str, help='the folder of the output files')
132
+ parser.add_argument('--device_ids', type=str, default='0,1')
133
+ parser.add_argument('--workers', type=int, default=4)
134
+
135
+ opt = parser.parse_args()
136
+ filenames = list()
137
+ VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
138
+ VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
139
+ extensions = VIDEO_EXTENSIONS
140
+
141
+ for ext in extensions:
142
+ os.listdir(f'{opt.input_dir}')
143
+ print(f'{opt.input_dir}/*.{ext}')
144
+ filenames += sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))  # accumulate matches across extensions
145
+ print('Total number of videos:', len(filenames))
146
+ pool = Pool(opt.workers)
147
+ args_list = cycle([opt])
148
+ device_ids = opt.device_ids.split(",")
149
+ device_ids = cycle(device_ids)
150
+ for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
151
+ pass
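
A minimal single-image sketch of the class above; the input path is hypothetical and the facexlib weights are fetched on first use:

```python
from PIL import Image
from src.face3d.extract_kp_videos_safe import KeypointExtractor  # adjust to your package layout

extractor = KeypointExtractor(device='cuda')                # or 'cpu'
img = Image.open('examples/source_image/face.png').convert('RGB')
kp = extractor.extract_keypoint(img, name='examples/face')  # also writes examples/face.txt
print(kp.shape)                                             # (68, 2); all -1 if no face was found
```
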
src/face3d/models/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """This package contains modules related to objective functions, optimizations, and network architectures.
2
+
3
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
+ You need to implement the following five functions:
5
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
+ -- <set_input>: unpack data from dataset and apply preprocessing.
7
+ -- <forward>: produce intermediate results.
8
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
+
11
+ In the function <__init__>, you need to define four lists:
12
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
+ -- self.model_names (str list): define networks used in our training.
14
+ -- self.visual_names (str list): specify the images that you want to display and save.
15
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16
+
17
+ Now you can use the model class by specifying flag '--model dummy'.
18
+ See our template model class 'template_model.py' for more details.
19
+ """
20
+
21
+ import importlib
22
+ from src.face3d.models.base_model import BaseModel
23
+
24
+
25
+ def find_model_using_name(model_name):
26
+ """Import the module "models/[model_name]_model.py".
27
+
28
+ In the file, the class called DatasetNameModel() will
29
+ be instantiated. It has to be a subclass of BaseModel,
30
+ and it is case-insensitive.
31
+ """
32
+ model_filename = "face3d.models." + model_name + "_model"
33
+ modellib = importlib.import_module(model_filename)
34
+ model = None
35
+ target_model_name = model_name.replace('_', '') + 'model'
36
+ for name, cls in modellib.__dict__.items():
37
+ if name.lower() == target_model_name.lower() \
38
+ and issubclass(cls, BaseModel):
39
+ model = cls
40
+
41
+ if model is None:
42
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43
+ exit(0)
44
+
45
+ return model
46
+
47
+
48
+ def get_option_setter(model_name):
49
+ """Return the static method <modify_commandline_options> of the model class."""
50
+ model_class = find_model_using_name(model_name)
51
+ return model_class.modify_commandline_options
52
+
53
+
54
+ def create_model(opt):
55
+ """Create a model given the option.
56
+
57
+ This function wraps the model class looked up by <find_model_using_name>.
58
+ This is the main interface between this package and 'train.py'/'test.py'
59
+
60
+ Example:
61
+ >>> from models import create_model
62
+ >>> model = create_model(opt)
63
+ """
64
+ model = find_model_using_name(opt.model)
65
+ instance = model(opt)
66
+ print("model [%s] was created" % type(instance).__name__)
67
+ return instance
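
The lookup above is purely name-based; for the hypothetical 'dummy' model mentioned in the package docstring it resolves as follows:

```python
# How find_model_using_name resolves a (hypothetical) '--model dummy' flag:
model_name = 'dummy'
module_path = "face3d.models." + model_name + "_model"  # -> face3d.models.dummy_model
wanted_class = model_name.replace('_', '') + 'model'    # matches class DummyModel case-insensitively
print(module_path, wanted_class)
# create_model(opt) then simply does: find_model_using_name(opt.model)(opt)
```
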
src/face3d/models/arcface_torch/README.md ADDED
@@ -0,0 +1,164 @@
1
+ # Distributed Arcface Training in Pytorch
2
+
3
+ This is a deep learning library that makes face recognition training efficient and effective, and that can train tens of
4
+ millions of identities on a single server.
5
+
6
+ ## Requirements
7
+
8
+ - Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
9
+ - `pip install -r requirements.txt`.
10
+ - Download the dataset
11
+ from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
12
+ .
13
+
14
+ ## How to Train
15
+
16
+ To train a model, run `train.py` with the path to the configs:
17
+
18
+ ### 1. Single node, 8 GPUs:
19
+
20
+ ```shell
21
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
22
+ ```
23
+
24
+ ### 2. Multiple nodes, each node 8 GPUs:
25
+
26
+ Node 0:
27
+
28
+ ```shell
29
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
30
+ ```
31
+
32
+ Node 1:
33
+
34
+ ```shell
35
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
36
+ ```
37
+
38
+ ### 3. Training resnet2060 with 8 GPUs:
39
+
40
+ ```shell
41
+ python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
42
+ ```
43
+
44
+ ## Model Zoo
45
+
46
+ - The models are available for non-commercial research purposes only.
47
+ - All models can be found here:
48
+ - [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
49
+ - [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
50
+
51
+ ### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
52
+
53
+ The ICCV2021-MFR test set consists of non-celebrities, so it has very little overlap with publicly available face
54
+ recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
55
+ As a result, we can fairly evaluate the performance of different algorithms.
56
+
57
+ For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol with FAR less than 0.000001 (1e-6). The
58
+ globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
59
+
60
+ For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol with FAR less than 0.0001 (1e-4).
61
+ The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
62
+ In total there are 13,928 positive pairs and 96,983,824 negative pairs.
63
+
64
+ | Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
65
+ | :---: | :--- | :--- | :--- |:--- |:--- |
66
+ | MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
67
+ | Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
68
+ | MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
69
+ | Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
70
+ | MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
71
+ | Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
72
+ | MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
73
+ | Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
74
+ | MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
75
+ | Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
76
+
77
+ ### Performance on IJB-C and Verification Datasets
78
+
79
+ | Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
80
+ | :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
81
+ | MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
82
+ | MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
83
+ | MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
84
+ | MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
85
+ | MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
86
+ | Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
87
+ | Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
88
+ | Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
89
+ | Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
90
+
91
+ [comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
92
+
93
+
94
+ ## [Speed Benchmark](docs/speed_benchmark.md)
95
+
96
+ **Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
97
+ classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy reaches the same
98
+ accuracy with several times faster training and a smaller GPU memory footprint.
99
+ Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition. Partial FC uses a
100
+ sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
101
+ sparse part of the parameters is updated, which saves a lot of GPU memory and computation. With Partial FC,
102
+ we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
103
+ training and mixed-precision training.
104
+
105
+ ![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
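+
+ The snippet below is only a schematic sketch of this class-center sampling idea, not the repository's actual Partial FC implementation: each step keeps the centers of the classes present in the batch plus a random subset of the remaining centers, and computes the softmax loss over that subset only.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ num_classes, emb_size, sample_rate = 10_000, 512, 0.1
+ centers = torch.randn(num_classes, emb_size)          # full class-center table
+
+ def partial_fc_step(embeddings, labels):
+     positive = labels.unique()                        # centers that must be kept this step
+     num_sample = max(int(sample_rate * num_classes), positive.numel())
+     mask = torch.ones(num_classes, dtype=torch.bool)
+     mask[positive] = False
+     negatives = torch.nonzero(mask).squeeze(1)
+     pick = torch.randperm(negatives.numel())[: num_sample - positive.numel()]
+     sampled = torch.cat([positive, negatives[pick]]).sort().values
+     new_labels = torch.searchsorted(sampled, labels)   # labels remapped into the sampled subset
+     logits = F.normalize(embeddings) @ F.normalize(centers[sampled]).t()
+     return F.cross_entropy(logits * 64.0, new_labels)  # the real loss adds an ArcFace/CosFace margin
+
+ loss = partial_fc_step(torch.randn(8, emb_size), torch.randint(0, num_classes, (8,)))
+ print(loss.item())
+ ```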
106
+
107
+ More details see
108
+ [speed_benchmark.md](docs/speed_benchmark.md) in docs.
109
+
110
+ ### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
111
+
112
+ `-` means training failed because of gpu memory limitations.
113
+
114
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
115
+ | :--- | :--- | :--- | :--- |
116
+ |125000 | 4681 | 4824 | 5004 |
117
+ |1400000 | **1672** | 3043 | 4738 |
118
+ |5500000 | **-** | **1389** | 3975 |
119
+ |8000000 | **-** | **-** | 3565 |
120
+ |16000000 | **-** | **-** | 2679 |
121
+ |29000000 | **-** | **-** | **1855** |
122
+
123
+ ### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
124
+
125
+ | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
126
+ | :--- | :--- | :--- | :--- |
127
+ |125000 | 7358 | 5306 | 4868 |
128
+ |1400000 | 32252 | 11178 | 6056 |
129
+ |5500000 | **-** | 32188 | 9854 |
130
+ |8000000 | **-** | **-** | 12310 |
131
+ |16000000 | **-** | **-** | 19950 |
132
+ |29000000 | **-** | **-** | 32324 |
133
+
134
+ ## Evaluation ICCV2021-MFR and IJB-C
135
+
136
+ More details see [eval.md](docs/eval.md) in docs.
137
+
138
+ ## Test
139
+
140
+ We tested many versions of PyTorch. Please create an issue if you are having trouble.
141
+
142
+ - [x] torch 1.6.0
143
+ - [x] torch 1.7.1
144
+ - [x] torch 1.8.0
145
+ - [x] torch 1.9.0
146
+
147
+ ## Citation
148
+
149
+ ```
150
+ @inproceedings{deng2019arcface,
151
+ title={Arcface: Additive angular margin loss for deep face recognition},
152
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
153
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
154
+ pages={4690--4699},
155
+ year={2019}
156
+ }
157
+ @inproceedings{an2020partical_fc,
158
+ title={Partial FC: Training 10 Million Identities on a Single Machine},
159
+ author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
160
+ Zhang, Debing and Fu Ying},
161
+ booktitle={Arxiv 2010.05222},
162
+ year={2020}
163
+ }
164
+ ```
src/face3d/models/arcface_torch/backbones/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2
+ from .mobilefacenet import get_mbf
3
+
4
+
5
+ def get_model(name, **kwargs):
6
+ # resnet
7
+ if name == "r18":
8
+ return iresnet18(False, **kwargs)
9
+ elif name == "r34":
10
+ return iresnet34(False, **kwargs)
11
+ elif name == "r50":
12
+ return iresnet50(False, **kwargs)
13
+ elif name == "r100":
14
+ return iresnet100(False, **kwargs)
15
+ elif name == "r200":
16
+ return iresnet200(False, **kwargs)
17
+ elif name == "r2060":
18
+ from .iresnet2060 import iresnet2060
19
+ return iresnet2060(False, **kwargs)
20
+ elif name == "mbf":
21
+ fp16 = kwargs.get("fp16", False)
22
+ num_features = kwargs.get("num_features", 512)
23
+ return get_mbf(fp16=fp16, num_features=num_features)
24
+ else:
25
+ raise ValueError()
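
A quick sketch of building a backbone from this factory and extracting a 512-d embedding; the import assumes the script is run from the arcface_torch directory (as the training scripts do), and the input must be a 112x112 aligned face crop for these models:

```python
import torch
from backbones import get_model  # adjust the import to your working directory / package layout

net = get_model("r50", fp16=False, num_features=512)
net.eval()
x = torch.randn(1, 3, 112, 112)   # a batch of aligned 112x112 face crops
with torch.no_grad():
    embedding = net(x)
print(embedding.shape)            # torch.Size([1, 512])
```
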
src/face3d/models/arcface_torch/backbones/iresnet.py ADDED
@@ -0,0 +1,187 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
8
+ """3x3 convolution with padding"""
9
+ return nn.Conv2d(in_planes,
10
+ out_planes,
11
+ kernel_size=3,
12
+ stride=stride,
13
+ padding=dilation,
14
+ groups=groups,
15
+ bias=False,
16
+ dilation=dilation)
17
+
18
+
19
+ def conv1x1(in_planes, out_planes, stride=1):
20
+ """1x1 convolution"""
21
+ return nn.Conv2d(in_planes,
22
+ out_planes,
23
+ kernel_size=1,
24
+ stride=stride,
25
+ bias=False)
26
+
27
+
28
+ class IBasicBlock(nn.Module):
29
+ expansion = 1
30
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
31
+ groups=1, base_width=64, dilation=1):
32
+ super(IBasicBlock, self).__init__()
33
+ if groups != 1 or base_width != 64:
34
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35
+ if dilation > 1:
36
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
38
+ self.conv1 = conv3x3(inplanes, planes)
39
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
40
+ self.prelu = nn.PReLU(planes)
41
+ self.conv2 = conv3x3(planes, planes, stride)
42
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
43
+ self.downsample = downsample
44
+ self.stride = stride
45
+
46
+ def forward(self, x):
47
+ identity = x
48
+ out = self.bn1(x)
49
+ out = self.conv1(out)
50
+ out = self.bn2(out)
51
+ out = self.prelu(out)
52
+ out = self.conv2(out)
53
+ out = self.bn3(out)
54
+ if self.downsample is not None:
55
+ identity = self.downsample(x)
56
+ out += identity
57
+ return out
58
+
59
+
60
+ class IResNet(nn.Module):
61
+ fc_scale = 7 * 7
62
+ def __init__(self,
63
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
64
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
65
+ super(IResNet, self).__init__()
66
+ self.fp16 = fp16
67
+ self.inplanes = 64
68
+ self.dilation = 1
69
+ if replace_stride_with_dilation is None:
70
+ replace_stride_with_dilation = [False, False, False]
71
+ if len(replace_stride_with_dilation) != 3:
72
+ raise ValueError("replace_stride_with_dilation should be None "
73
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74
+ self.groups = groups
75
+ self.base_width = width_per_group
76
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
77
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
78
+ self.prelu = nn.PReLU(self.inplanes)
79
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
80
+ self.layer2 = self._make_layer(block,
81
+ 128,
82
+ layers[1],
83
+ stride=2,
84
+ dilate=replace_stride_with_dilation[0])
85
+ self.layer3 = self._make_layer(block,
86
+ 256,
87
+ layers[2],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[1])
90
+ self.layer4 = self._make_layer(block,
91
+ 512,
92
+ layers[3],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[2])
95
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
96
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
97
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
98
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
99
+ nn.init.constant_(self.features.weight, 1.0)
100
+ self.features.weight.requires_grad = False
101
+
102
+ for m in self.modules():
103
+ if isinstance(m, nn.Conv2d):
104
+ nn.init.normal_(m.weight, 0, 0.1)
105
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
106
+ nn.init.constant_(m.weight, 1)
107
+ nn.init.constant_(m.bias, 0)
108
+
109
+ if zero_init_residual:
110
+ for m in self.modules():
111
+ if isinstance(m, IBasicBlock):
112
+ nn.init.constant_(m.bn2.weight, 0)
113
+
114
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
115
+ downsample = None
116
+ previous_dilation = self.dilation
117
+ if dilate:
118
+ self.dilation *= stride
119
+ stride = 1
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ conv1x1(self.inplanes, planes * block.expansion, stride),
123
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
124
+ )
125
+ layers = []
126
+ layers.append(
127
+ block(self.inplanes, planes, stride, downsample, self.groups,
128
+ self.base_width, previous_dilation))
129
+ self.inplanes = planes * block.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(
132
+ block(self.inplanes,
133
+ planes,
134
+ groups=self.groups,
135
+ base_width=self.base_width,
136
+ dilation=self.dilation))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ with torch.cuda.amp.autocast(self.fp16):
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.prelu(x)
145
+ x = self.layer1(x)
146
+ x = self.layer2(x)
147
+ x = self.layer3(x)
148
+ x = self.layer4(x)
149
+ x = self.bn2(x)
150
+ x = torch.flatten(x, 1)
151
+ x = self.dropout(x)
152
+ x = self.fc(x.float() if self.fp16 else x)
153
+ x = self.features(x)
154
+ return x
155
+
156
+
157
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
158
+ model = IResNet(block, layers, **kwargs)
159
+ if pretrained:
160
+ raise ValueError()
161
+ return model
162
+
163
+
164
+ def iresnet18(pretrained=False, progress=True, **kwargs):
165
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
166
+ progress, **kwargs)
167
+
168
+
169
+ def iresnet34(pretrained=False, progress=True, **kwargs):
170
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
171
+ progress, **kwargs)
172
+
173
+
174
+ def iresnet50(pretrained=False, progress=True, **kwargs):
175
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
176
+ progress, **kwargs)
177
+
178
+
179
+ def iresnet100(pretrained=False, progress=True, **kwargs):
180
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
181
+ progress, **kwargs)
182
+
183
+
184
+ def iresnet200(pretrained=False, progress=True, **kwargs):
185
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
186
+ progress, **kwargs)
187
+
src/face3d/models/arcface_torch/backbones/iresnet2060.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ assert torch.__version__ >= "1.8.1"
5
+ from torch.utils.checkpoint import checkpoint_sequential
6
+
7
+ __all__ = ['iresnet2060']
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes,
13
+ out_planes,
14
+ kernel_size=3,
15
+ stride=stride,
16
+ padding=dilation,
17
+ groups=groups,
18
+ bias=False,
19
+ dilation=dilation)
20
+
21
+
22
+ def conv1x1(in_planes, out_planes, stride=1):
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes,
25
+ out_planes,
26
+ kernel_size=1,
27
+ stride=stride,
28
+ bias=False)
29
+
30
+
31
+ class IBasicBlock(nn.Module):
32
+ expansion = 1
33
+
34
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
35
+ groups=1, base_width=64, dilation=1):
36
+ super(IBasicBlock, self).__init__()
37
+ if groups != 1 or base_width != 64:
38
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39
+ if dilation > 1:
40
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42
+ self.conv1 = conv3x3(inplanes, planes)
43
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44
+ self.prelu = nn.PReLU(planes)
45
+ self.conv2 = conv3x3(planes, planes, stride)
46
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47
+ self.downsample = downsample
48
+ self.stride = stride
49
+
50
+ def forward(self, x):
51
+ identity = x
52
+ out = self.bn1(x)
53
+ out = self.conv1(out)
54
+ out = self.bn2(out)
55
+ out = self.prelu(out)
56
+ out = self.conv2(out)
57
+ out = self.bn3(out)
58
+ if self.downsample is not None:
59
+ identity = self.downsample(x)
60
+ out += identity
61
+ return out
62
+
63
+
64
+ class IResNet(nn.Module):
65
+ fc_scale = 7 * 7
66
+
67
+ def __init__(self,
68
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
69
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70
+ super(IResNet, self).__init__()
71
+ self.fp16 = fp16
72
+ self.inplanes = 64
73
+ self.dilation = 1
74
+ if replace_stride_with_dilation is None:
75
+ replace_stride_with_dilation = [False, False, False]
76
+ if len(replace_stride_with_dilation) != 3:
77
+ raise ValueError("replace_stride_with_dilation should be None "
78
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79
+ self.groups = groups
80
+ self.base_width = width_per_group
81
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83
+ self.prelu = nn.PReLU(self.inplanes)
84
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85
+ self.layer2 = self._make_layer(block,
86
+ 128,
87
+ layers[1],
88
+ stride=2,
89
+ dilate=replace_stride_with_dilation[0])
90
+ self.layer3 = self._make_layer(block,
91
+ 256,
92
+ layers[2],
93
+ stride=2,
94
+ dilate=replace_stride_with_dilation[1])
95
+ self.layer4 = self._make_layer(block,
96
+ 512,
97
+ layers[3],
98
+ stride=2,
99
+ dilate=replace_stride_with_dilation[2])
100
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
102
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104
+ nn.init.constant_(self.features.weight, 1.0)
105
+ self.features.weight.requires_grad = False
106
+
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.normal_(m.weight, 0, 0.1)
110
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111
+ nn.init.constant_(m.weight, 1)
112
+ nn.init.constant_(m.bias, 0)
113
+
114
+ if zero_init_residual:
115
+ for m in self.modules():
116
+ if isinstance(m, IBasicBlock):
117
+ nn.init.constant_(m.bn2.weight, 0)
118
+
119
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120
+ downsample = None
121
+ previous_dilation = self.dilation
122
+ if dilate:
123
+ self.dilation *= stride
124
+ stride = 1
125
+ if stride != 1 or self.inplanes != planes * block.expansion:
126
+ downsample = nn.Sequential(
127
+ conv1x1(self.inplanes, planes * block.expansion, stride),
128
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129
+ )
130
+ layers = []
131
+ layers.append(
132
+ block(self.inplanes, planes, stride, downsample, self.groups,
133
+ self.base_width, previous_dilation))
134
+ self.inplanes = planes * block.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(
137
+ block(self.inplanes,
138
+ planes,
139
+ groups=self.groups,
140
+ base_width=self.base_width,
141
+ dilation=self.dilation))
142
+
143
+ return nn.Sequential(*layers)
144
+
145
+ def checkpoint(self, func, num_seg, x):
146
+ if self.training:
147
+ return checkpoint_sequential(func, num_seg, x)
148
+ else:
149
+ return func(x)
150
+
151
+ def forward(self, x):
152
+ with torch.cuda.amp.autocast(self.fp16):
153
+ x = self.conv1(x)
154
+ x = self.bn1(x)
155
+ x = self.prelu(x)
156
+ x = self.layer1(x)
157
+ x = self.checkpoint(self.layer2, 20, x)
158
+ x = self.checkpoint(self.layer3, 100, x)
159
+ x = self.layer4(x)
160
+ x = self.bn2(x)
161
+ x = torch.flatten(x, 1)
162
+ x = self.dropout(x)
163
+ x = self.fc(x.float() if self.fp16 else x)
164
+ x = self.features(x)
165
+ return x
166
+
167
+
168
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169
+ model = IResNet(block, layers, **kwargs)
170
+ if pretrained:
171
+ raise ValueError()
172
+ return model
173
+
174
+
175
+ def iresnet2060(pretrained=False, progress=True, **kwargs):
176
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
src/face3d/models/arcface_torch/backbones/mobilefacenet.py ADDED
@@ -0,0 +1,130 @@
1
+ '''
2
+ Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3
+ Original author cavalleria
4
+ '''
5
+
6
+ import torch.nn as nn
7
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8
+ import torch
9
+
10
+
11
+ class Flatten(Module):
12
+ def forward(self, x):
13
+ return x.view(x.size(0), -1)
14
+
15
+
16
+ class ConvBlock(Module):
17
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18
+ super(ConvBlock, self).__init__()
19
+ self.layers = nn.Sequential(
20
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21
+ BatchNorm2d(num_features=out_c),
22
+ PReLU(num_parameters=out_c)
23
+ )
24
+
25
+ def forward(self, x):
26
+ return self.layers(x)
27
+
28
+
29
+ class LinearBlock(Module):
30
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31
+ super(LinearBlock, self).__init__()
32
+ self.layers = nn.Sequential(
33
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34
+ BatchNorm2d(num_features=out_c)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layers(x)
39
+
40
+
41
+ class DepthWise(Module):
42
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43
+ super(DepthWise, self).__init__()
44
+ self.residual = residual
45
+ self.layers = nn.Sequential(
46
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49
+ )
50
+
51
+ def forward(self, x):
52
+ short_cut = None
53
+ if self.residual:
54
+ short_cut = x
55
+ x = self.layers(x)
56
+ if self.residual:
57
+ output = short_cut + x
58
+ else:
59
+ output = x
60
+ return output
61
+
62
+
63
+ class Residual(Module):
64
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65
+ super(Residual, self).__init__()
66
+ modules = []
67
+ for _ in range(num_block):
68
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69
+ self.layers = Sequential(*modules)
70
+
71
+ def forward(self, x):
72
+ return self.layers(x)
73
+
74
+
75
+ class GDC(Module):
76
+ def __init__(self, embedding_size):
77
+ super(GDC, self).__init__()
78
+ self.layers = nn.Sequential(
79
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80
+ Flatten(),
81
+ Linear(512, embedding_size, bias=False),
82
+ BatchNorm1d(embedding_size))
83
+
84
+ def forward(self, x):
85
+ return self.layers(x)
86
+
87
+
88
+ class MobileFaceNet(Module):
89
+ def __init__(self, fp16=False, num_features=512):
90
+ super(MobileFaceNet, self).__init__()
91
+ scale = 2
92
+ self.fp16 = fp16
93
+ self.layers = nn.Sequential(
94
+ ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
95
+ ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
96
+ DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
97
+ Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
98
+ DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
99
+ Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
100
+ DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
101
+ Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
102
+ )
103
+ self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104
+ self.features = GDC(num_features)
105
+ self._initialize_weights()
106
+
107
+ def _initialize_weights(self):
108
+ for m in self.modules():
109
+ if isinstance(m, nn.Conv2d):
110
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111
+ if m.bias is not None:
112
+ m.bias.data.zero_()
113
+ elif isinstance(m, nn.BatchNorm2d):
114
+ m.weight.data.fill_(1)
115
+ m.bias.data.zero_()
116
+ elif isinstance(m, nn.Linear):
117
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
118
+ if m.bias is not None:
119
+ m.bias.data.zero_()
120
+
121
+ def forward(self, x):
122
+ with torch.cuda.amp.autocast(self.fp16):
123
+ x = self.layers(x)
124
+ x = self.conv_sep(x.float() if self.fp16 else x)
125
+ x = self.features(x)
126
+ return x
127
+
128
+
129
+ def get_mbf(fp16, num_features):
130
+ return MobileFaceNet(fp16, num_features)
src/face3d/models/arcface_torch/configs/3millions.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 1.0
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
src/face3d/models/arcface_torch/configs/3millions_pfc.py ADDED
@@ -0,0 +1,23 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # configs for test speed
4
+
5
+ config = edict()
6
+ config.loss = "arcface"
7
+ config.network = "r50"
8
+ config.resume = False
9
+ config.output = None
10
+ config.embedding_size = 512
11
+ config.sample_rate = 0.1
12
+ config.fp16 = True
13
+ config.momentum = 0.9
14
+ config.weight_decay = 5e-4
15
+ config.batch_size = 128
16
+ config.lr = 0.1 # batch size is 512
17
+
18
+ config.rec = "synthetic"
19
+ config.num_classes = 300 * 10000
20
+ config.num_epoch = 30
21
+ config.warmup_epoch = -1
22
+ config.decay_epoch = [10, 16, 22]
23
+ config.val_targets = []
src/face3d/models/arcface_torch/configs/__init__.py ADDED
File without changes
src/face3d/models/arcface_torch/configs/base.py ADDED
@@ -0,0 +1,56 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = "ms1mv3_arcface_r50"
12
+
13
+ config.dataset = "ms1m-retinaface-t1"
14
+ config.embedding_size = 512
15
+ config.sample_rate = 1
16
+ config.fp16 = False
17
+ config.momentum = 0.9
18
+ config.weight_decay = 5e-4
19
+ config.batch_size = 128
20
+ config.lr = 0.1 # batch size is 512
21
+
22
+ if config.dataset == "emore":
23
+ config.rec = "/train_tmp/faces_emore"
24
+ config.num_classes = 85742
25
+ config.num_image = 5822653
26
+ config.num_epoch = 16
27
+ config.warmup_epoch = -1
28
+ config.decay_epoch = [8, 14, ]
29
+ config.val_targets = ["lfw", ]
30
+
31
+ elif config.dataset == "ms1m-retinaface-t1":
32
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
33
+ config.num_classes = 93431
34
+ config.num_image = 5179510
35
+ config.num_epoch = 25
36
+ config.warmup_epoch = -1
37
+ config.decay_epoch = [11, 17, 22]
38
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
39
+
40
+ elif config.dataset == "glint360k":
41
+ config.rec = "/train_tmp/glint360k"
42
+ config.num_classes = 360232
43
+ config.num_image = 17091657
44
+ config.num_epoch = 20
45
+ config.warmup_epoch = -1
46
+ config.decay_epoch = [8, 12, 15, 18]
47
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
48
+
49
+ elif config.dataset == "webface":
50
+ config.rec = "/train_tmp/faces_webface_112x112"
51
+ config.num_classes = 10572
52
+ config.num_image = "forget"
53
+ config.num_epoch = 34
54
+ config.warmup_epoch = -1
55
+ config.decay_epoch = [20, 28, 32]
56
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
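
The recurring `config.lr = 0.1  # batch size is 512` comment suggests the configured rate corresponds to a total batch size of 512 (e.g. 128 per GPU x 4 GPUs). A hedged sketch of the usual linear-scaling rule for other setups; whether the training script applies this automatically is not visible in this diff:

```python
base_lr, base_total_batch = 0.1, 512
batch_size_per_gpu, world_size = 128, 8              # e.g. the single-node 8-GPU command in the README
lr = base_lr * (batch_size_per_gpu * world_size) / base_total_batch
print(lr)                                            # 0.2 for a total batch of 1024
```
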
src/face3d/models/arcface_torch/configs/glint360k_mbf.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 0.1
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r100.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r100"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r18.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r18"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r34.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r34"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r50.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "cosface"
9
+ config.network = "r50"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 5e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/glint360k"
21
+ config.num_classes = 360232
22
+ config.num_image = 17091657
23
+ config.num_epoch = 20
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [8, 12, 15, 18]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py ADDED
@@ -0,0 +1,26 @@
1
+ from easydict import EasyDict as edict
2
+
3
+ # make training faster
4
+ # our RAM is 256G
5
+ # mount -t tmpfs -o size=140G tmpfs /train_tmp
6
+
7
+ config = edict()
8
+ config.loss = "arcface"
9
+ config.network = "mbf"
10
+ config.resume = False
11
+ config.output = None
12
+ config.embedding_size = 512
13
+ config.sample_rate = 1.0
14
+ config.fp16 = True
15
+ config.momentum = 0.9
16
+ config.weight_decay = 2e-4
17
+ config.batch_size = 128
18
+ config.lr = 0.1 # batch size is 512
19
+
20
+ config.rec = "/train_tmp/ms1m-retinaface-t1"
21
+ config.num_classes = 93431
22
+ config.num_image = 5179510
23
+ config.num_epoch = 30
24
+ config.warmup_epoch = -1
25
+ config.decay_epoch = [10, 20, 25]
26
+ config.val_targets = ["lfw", "cfp_fp", "agedb_30"]