marlenezw committed on
Commit
c46f04a
1 Parent(s): b9aa58b

Upload marlenezw/audio-driven-animations/MakeItTalk with huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +3 -0
  2. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/.gitignore +8 -0
  3. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/LICENSE +201 -0
  4. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/README.md +82 -0
  5. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__init__.py +0 -0
  6. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-37.pyc +0 -0
  7. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-39.pyc +0 -0
  8. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/ckpt/.gitkeep +0 -0
  9. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__init__.py +0 -0
  10. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-37.pyc +0 -0
  11. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-39.pyc +0 -0
  12. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-37.pyc +0 -0
  13. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-39.pyc +0 -0
  14. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-37.pyc +0 -0
  15. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-39.pyc +0 -0
  16. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/coord_conv.py +157 -0
  17. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/dataloader.py +368 -0
  18. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/evaler.py +151 -0
  19. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/models.py +228 -0
  20. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/eval.py +77 -0
  21. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw.png +3 -0
  22. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw_table.png +3 -0
  23. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/requirements.txt +12 -0
  24. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/scripts/eval_wflw.sh +10 -0
  25. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__init__.py +0 -0
  26. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-37.pyc +0 -0
  27. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  28. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-37.pyc +0 -0
  29. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-39.pyc +0 -0
  30. marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/utils.py +354 -0
  31. marlenezw/audio-driven-animations/MakeItTalk/__init__.py +0 -0
  32. marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-37.pyc +0 -0
  33. marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-39.pyc +0 -0
  34. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/CODEOWNERS +1 -0
  35. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/LICENCE.txt +21 -0
  36. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/README.md +98 -0
  37. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.py +0 -0
  38. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.pyc +0 -0
  39. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/__init__.cpython-36.pyc +0 -0
  40. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/data_loading_functions.cpython-36.pyc +0 -0
  41. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deep_heatmaps_model_fusion_net.cpython-36.pyc +0 -0
  42. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deformation_functions.cpython-36.pyc +0 -0
  43. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/logging_functions.cpython-36.pyc +0 -0
  44. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/menpo_functions.cpython-36.pyc +0 -0
  45. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/ops.cpython-36.pyc +0 -0
  46. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/pdm_clm_functions.cpython-36.pyc +0 -0
  47. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/crop_training_set.py +38 -0
  48. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.py +161 -0
  49. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.pyc +0 -0
  50. marlenezw/audio-driven-animations/MakeItTalk/face_of_art/deep_heatmaps_model_fusion_net.py +872 -0
.gitattributes CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 marlenezw/audio-driven-animations/MakeItTalk/examples/ckpt filter=lfs diff=lfs merge=lfs -text
 MakeItTalk/examples/ckpt filter=lfs diff=lfs merge=lfs -text
+marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw.png filter=lfs diff=lfs merge=lfs -text
+marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw_table.png filter=lfs diff=lfs merge=lfs -text
+marlenezw/audio-driven-animations/MakeItTalk/face_of_art/old/teaser.png filter=lfs diff=lfs merge=lfs -text
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/.gitignore ADDED
@@ -0,0 +1,8 @@
+# Python generated files
+*.pyc
+
+# Project related files
+ckpt/*.pth
+dataset/*
+!dataset/!.py
+experiments/*
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/README.md ADDED
@@ -0,0 +1,82 @@
+# AdaptiveWingLoss
+## [arXiv](https://arxiv.org/abs/1904.07399)
+PyTorch implementation of Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression.
+
+<img src='images/wflw.png' width="1000px">
+
+## Update Logs:
+### October 28, 2019
+* Pretrained model and evaluation code on the WFLW dataset are released.
+
+## Installation
+#### Note: The code was originally developed under Python 2.x and PyTorch 0.4. This released version was revised from the original code and tested on Python 3.5.7 and PyTorch 1.3.0.
+
+Install system requirements:
+```
+sudo apt-get install python3-dev python3-pip python3-tk libglib2.0-0
+```
+
+Install Python dependencies:
+```
+pip3 install -r requirements.txt
+```
+
+## Run Evaluation on WFLW dataset
+1. Download and process the WFLW dataset
+   * Download the WFLW dataset and annotations from [Here](https://wywu.github.io/projects/LAB/WFLW.html).
+   * Unzip the WFLW dataset and annotations and move the files into the ```./dataset``` directory. Your directory should look like this:
+   ```
+   AdaptiveWingLoss
+   └───dataset
+
+      └───WFLW_annotations
+      │   └───list_98pt_rect_attr_train_test
+      │   │
+      │   └───list_98pt_test
+
+      └───WFLW_images
+          └───0--Parade
+
+          └───...
+   ```
+   * Inside the ```./dataset``` directory, run:
+   ```
+   python convert_WFLW.py
+   ```
+   A new directory ```./dataset/WFLW_test``` should be generated with 2500 processed testing images and corresponding landmarks.
+
+2. Download the pretrained model from [Google Drive](https://drive.google.com/file/d/1HZaSjLoorQ4QCEx7PRTxOmg0bBPYSqhH/view?usp=sharing) and put it in the ```./ckpt``` directory.
+
+3. Within the ```./scripts``` directory, run the following command:
+```
+sh eval_wflw.sh
+```
+
+<img src='images/wflw_table.png' width="800px">
+*GTBbox indicates that the ground-truth landmarks are used as the bounding box to crop faces.
+
+## Future Plans
+- [x] Release evaluation code and pretrained model on the WFLW dataset.
+
+- [ ] Release training code on the WFLW dataset.
+
+- [ ] Release pretrained models and code on the 300W, AFLW and COFW datasets.
+
+- [ ] Release facial landmark detection API.
+
+
+## Citation
+If you find this work useful for your research, please cite the following paper.
+
+```
+@InProceedings{Wang_2019_ICCV,
+author = {Wang, Xinyao and Bo, Liefeng and Fuxin, Li},
+title = {Adaptive Wing Loss for Robust Face Alignment via Heatmap Regression},
+booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
+month = {October},
+year = {2019}
+}
+```
+
+## Acknowledgments
+This repository borrows or partially modifies the hourglass model and data processing code from [face alignment](https://github.com/1adrianb/face-alignment) and [pose-hg-train](https://github.com/princeton-vl/pose-hg-train).
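As a rough, hypothetical sketch of the evaluation pipeline this README describes (not part of the commit), loading the pretrained WFLW model and decoding its heatmaps into landmarks might look like the following. The checkpoint filename under `./ckpt` is an assumption; the 4 hourglass blocks, 98 landmarks, and 256×256 crop size mirror the defaults used in `eval.py` and `core/evaler.py`.

```python
# Hypothetical inference sketch: load the downloaded WFLW checkpoint and decode
# heatmaps into 98 landmark coordinates. The .pth filename is an assumption.
import torch
from core import models
from utils.utils import get_preds_fromhm  # same decoder core/evaler.py uses

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.FAN(num_modules=4, end_relu=False, gray_scale=False, num_landmarks=98)

checkpoint = torch.load("ckpt/WFLW_4HG.pth", map_location=device)
state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
weights = model.state_dict()
weights.update({k: v for k, v in state_dict.items() if k in weights})
model.load_state_dict(weights)
model = model.to(device).eval()

with torch.no_grad():
    face = torch.rand(1, 3, 256, 256, device=device)   # stand-in for a cropped face
    heatmaps, _ = model(face)
    landmarks, _ = get_preds_fromhm(heatmaps[-1][:, :-1].cpu())
    print(landmarks.shape)  # expected (1, 98, 2), in 64x64 heatmap coordinates (x4 for the crop)
```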
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__init__.py ADDED
File without changes
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (164 Bytes).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (179 Bytes).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/ckpt/.gitkeep ADDED
File without changes
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__init__.py ADDED
File without changes
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (169 Bytes).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (184 Bytes).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-37.pyc ADDED
Binary file (4.33 kB).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/coord_conv.cpython-39.pyc ADDED
Binary file (4.38 kB).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-37.pyc ADDED
Binary file (5.77 kB).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/__pycache__/models.cpython-39.pyc ADDED
Binary file (5.83 kB).
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/coord_conv.py ADDED
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+class AddCoordsTh(nn.Module):
+    def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
+        super(AddCoordsTh, self).__init__()
+        self.x_dim = x_dim
+        self.y_dim = y_dim
+        self.with_r = with_r
+        self.with_boundary = with_boundary
+
+    def forward(self, input_tensor, heatmap=None):
+        """
+        input_tensor: (batch, c, x_dim, y_dim)
+        """
+        batch_size_tensor = input_tensor.shape[0]
+
+        xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32).to(device)
+        xx_ones = xx_ones.unsqueeze(-1)
+
+        xx_range = torch.arange(self.x_dim, dtype=torch.int32).unsqueeze(0).to(device)
+        xx_range = xx_range.unsqueeze(1)
+
+        xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
+        xx_channel = xx_channel.unsqueeze(-1)
+
+
+        yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32).to(device)
+        yy_ones = yy_ones.unsqueeze(1)
+
+        yy_range = torch.arange(self.y_dim, dtype=torch.int32).unsqueeze(0).to(device)
+        yy_range = yy_range.unsqueeze(-1)
+
+        yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
+        yy_channel = yy_channel.unsqueeze(-1)
+
+        xx_channel = xx_channel.permute(0, 3, 2, 1)
+        yy_channel = yy_channel.permute(0, 3, 2, 1)
+
+        xx_channel = xx_channel / (self.x_dim - 1)
+        yy_channel = yy_channel / (self.y_dim - 1)
+
+        xx_channel = xx_channel * 2 - 1
+        yy_channel = yy_channel * 2 - 1
+
+        xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
+        yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
+
+        if self.with_boundary and type(heatmap) != type(None):
+            boundary_channel = torch.clamp(heatmap[:, -1:, :, :],
+                                           0.0, 1.0)
+
+            zero_tensor = torch.zeros_like(xx_channel)
+            xx_boundary_channel = torch.where(boundary_channel > 0.05,
+                                              xx_channel, zero_tensor)
+            yy_boundary_channel = torch.where(boundary_channel > 0.05,
+                                              yy_channel, zero_tensor)
+        if self.with_boundary and type(heatmap) != type(None):
+            xx_boundary_channel = xx_boundary_channel.to(device)
+            yy_boundary_channel = yy_boundary_channel.to(device)
+
+        ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
+
+
+        if self.with_r:
+            rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
+            rr = rr / torch.max(rr)
+            ret = torch.cat([ret, rr], dim=1)
+
+        if self.with_boundary and type(heatmap) != type(None):
+            ret = torch.cat([ret, xx_boundary_channel,
+                             yy_boundary_channel], dim=1)
+        return ret
+
+
+class CoordConvTh(nn.Module):
+    """CoordConv layer as in the paper."""
+    def __init__(self, x_dim, y_dim, with_r, with_boundary,
+                 in_channels, first_one=False, *args, **kwargs):
+        super(CoordConvTh, self).__init__()
+        self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
+                                     with_boundary=with_boundary)
+        in_channels += 2
+        if with_r:
+            in_channels += 1
+        if with_boundary and not first_one:
+            in_channels += 2
+        self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
+
+    def forward(self, input_tensor, heatmap=None):
+        ret = self.addcoords(input_tensor, heatmap)
+        last_channel = ret[:, -2:, :, :]
+        ret = self.conv(ret)
+        return ret, last_channel
+
+
+'''
+An alternative implementation for PyTorch with auto-inferring the x-y dimensions.
+'''
+class AddCoords(nn.Module):
+
+    def __init__(self, with_r=False):
+        super().__init__()
+        self.with_r = with_r
+
+    def forward(self, input_tensor):
+        """
+        Args:
+            input_tensor: shape(batch, channel, x_dim, y_dim)
+        """
+        batch_size, _, x_dim, y_dim = input_tensor.size()
+
+        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
+        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)
+
+        xx_channel = xx_channel / (x_dim - 1)
+        yy_channel = yy_channel / (y_dim - 1)
+
+        xx_channel = xx_channel * 2 - 1
+        yy_channel = yy_channel * 2 - 1
+
+        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
+        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
+
+        if input_tensor.is_cuda:
+            xx_channel = xx_channel.to(device)
+            yy_channel = yy_channel.to(device)
+
+        ret = torch.cat([
+            input_tensor,
+            xx_channel.type_as(input_tensor),
+            yy_channel.type_as(input_tensor)], dim=1)
+
+        if self.with_r:
+            rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
+            if input_tensor.is_cuda:
+                rr = rr.to(device)
+            ret = torch.cat([ret, rr], dim=1)
+
+        return ret
+
+
+class CoordConv(nn.Module):
+
+    def __init__(self, in_channels, out_channels, with_r=False, **kwargs):
+        super().__init__()
+        self.addcoords = AddCoords(with_r=with_r)
+        self.conv = nn.Conv2d(in_channels + 2, out_channels, **kwargs)
+
+    def forward(self, x):
+        ret = self.addcoords(x)
+        ret = self.conv(ret)
+        return ret
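As a quick, hypothetical usage sketch of the auto-inferring `CoordConv` variant added above (shapes are illustrative only, not from this commit): the module concatenates two normalized coordinate channels to its input before the convolution, so `in_channels=3` becomes 5 channels seen by the internal `nn.Conv2d`.

```python
# Minimal CoordConv sketch with illustrative shapes.
import torch
from core.coord_conv import CoordConv

layer = CoordConv(in_channels=3, out_channels=8, with_r=False, kernel_size=3, padding=1)
x = torch.rand(2, 3, 64, 64)   # (batch, channels, H, W)
y = layer(x)                   # AddCoords appends x/y channels: 3 + 2 = 5 conv inputs
print(y.shape)                 # torch.Size([2, 8, 64, 64])
```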
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/dataloader.py ADDED
@@ -0,0 +1,368 @@
1
+ import sys
2
+ import os
3
+ import random
4
+ import glob
5
+ import torch
6
+ from skimage import io
7
+ from skimage import transform as ski_transform
8
+ from skimage.color import rgb2gray
9
+ import scipy.io as sio
10
+ from scipy import interpolate
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ from torch.utils.data import Dataset, DataLoader
14
+ from torchvision import transforms, utils
15
+ from torchvision.transforms import Lambda, Compose
16
+ from torchvision.transforms.functional import adjust_brightness, adjust_contrast, adjust_saturation, adjust_hue
17
+ from utils.utils import cv_crop, cv_rotate, draw_gaussian, transform, power_transform, shuffle_lr, fig2data, generate_weight_map
18
+ from PIL import Image
19
+ import cv2
20
+ import copy
21
+ import math
22
+ from imgaug import augmenters as iaa
23
+
24
+
25
+ class AddBoundary(object):
26
+ def __init__(self, num_landmarks=68):
27
+ self.num_landmarks = num_landmarks
28
+
29
+ def __call__(self, sample):
30
+ landmarks_64 = np.floor(sample['landmarks'] / 4.0)
31
+ if self.num_landmarks == 68:
32
+ boundaries = {}
33
+ boundaries['cheek'] = landmarks_64[0:17]
34
+ boundaries['left_eyebrow'] = landmarks_64[17:22]
35
+ boundaries['right_eyebrow'] = landmarks_64[22:27]
36
+ boundaries['uper_left_eyelid'] = landmarks_64[36:40]
37
+ boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [36, 41, 40, 39]])
38
+ boundaries['upper_right_eyelid'] = landmarks_64[42:46]
39
+ boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [42, 47, 46, 45]])
40
+ boundaries['noise'] = landmarks_64[27:31]
41
+ boundaries['noise_bot'] = landmarks_64[31:36]
42
+ boundaries['upper_outer_lip'] = landmarks_64[48:55]
43
+ boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [60, 61, 62, 63, 64]])
44
+ boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [48, 59, 58, 57, 56, 55, 54]])
45
+ boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
46
+ elif self.num_landmarks == 98:
47
+ boundaries = {}
48
+ boundaries['cheek'] = landmarks_64[0:33]
49
+ boundaries['left_eyebrow'] = landmarks_64[33:38]
50
+ boundaries['right_eyebrow'] = landmarks_64[42:47]
51
+ boundaries['uper_left_eyelid'] = landmarks_64[60:65]
52
+ boundaries['lower_left_eyelid'] = np.array([landmarks_64[i] for i in [60, 67, 66, 65, 64]])
53
+ boundaries['upper_right_eyelid'] = landmarks_64[68:73]
54
+ boundaries['lower_right_eyelid'] = np.array([landmarks_64[i] for i in [68, 75, 74, 73, 72]])
55
+ boundaries['noise'] = landmarks_64[51:55]
56
+ boundaries['noise_bot'] = landmarks_64[55:60]
57
+ boundaries['upper_outer_lip'] = landmarks_64[76:83]
58
+ boundaries['upper_inner_lip'] = np.array([landmarks_64[i] for i in [88, 89, 90, 91, 92]])
59
+ boundaries['lower_outer_lip'] = np.array([landmarks_64[i] for i in [76, 87, 86, 85, 84, 83, 82]])
60
+ boundaries['lower_inner_lip'] = np.array([landmarks_64[i] for i in [88, 95, 94, 93, 92]])
61
+ elif self.num_landmarks == 19:
62
+ boundaries = {}
63
+ boundaries['left_eyebrow'] = landmarks_64[0:3]
64
+ boundaries['right_eyebrow'] = landmarks_64[3:5]
65
+ boundaries['left_eye'] = landmarks_64[6:9]
66
+ boundaries['right_eye'] = landmarks_64[9:12]
67
+ boundaries['noise'] = landmarks_64[12:15]
68
+
69
+ elif self.num_landmarks == 29:
70
+ boundaries = {}
71
+ boundaries['upper_left_eyebrow'] = np.stack([
72
+ landmarks_64[0],
73
+ landmarks_64[4],
74
+ landmarks_64[2]
75
+ ], axis=0)
76
+ boundaries['lower_left_eyebrow'] = np.stack([
77
+ landmarks_64[0],
78
+ landmarks_64[5],
79
+ landmarks_64[2]
80
+ ], axis=0)
81
+ boundaries['upper_right_eyebrow'] = np.stack([
82
+ landmarks_64[1],
83
+ landmarks_64[6],
84
+ landmarks_64[3]
85
+ ], axis=0)
86
+ boundaries['lower_right_eyebrow'] = np.stack([
87
+ landmarks_64[1],
88
+ landmarks_64[7],
89
+ landmarks_64[3]
90
+ ], axis=0)
91
+ boundaries['upper_left_eye'] = np.stack([
92
+ landmarks_64[8],
93
+ landmarks_64[12],
94
+ landmarks_64[10]
95
+ ], axis=0)
96
+ boundaries['lower_left_eye'] = np.stack([
97
+ landmarks_64[8],
98
+ landmarks_64[13],
99
+ landmarks_64[10]
100
+ ], axis=0)
101
+ boundaries['upper_right_eye'] = np.stack([
102
+ landmarks_64[9],
103
+ landmarks_64[14],
104
+ landmarks_64[11]
105
+ ], axis=0)
106
+ boundaries['lower_right_eye'] = np.stack([
107
+ landmarks_64[9],
108
+ landmarks_64[15],
109
+ landmarks_64[11]
110
+ ], axis=0)
111
+ boundaries['noise'] = np.stack([
112
+ landmarks_64[18],
113
+ landmarks_64[21],
114
+ landmarks_64[19]
115
+ ], axis=0)
116
+ boundaries['outer_upper_lip'] = np.stack([
117
+ landmarks_64[22],
118
+ landmarks_64[24],
119
+ landmarks_64[23]
120
+ ], axis=0)
121
+ boundaries['inner_upper_lip'] = np.stack([
122
+ landmarks_64[22],
123
+ landmarks_64[25],
124
+ landmarks_64[23]
125
+ ], axis=0)
126
+ boundaries['outer_lower_lip'] = np.stack([
127
+ landmarks_64[22],
128
+ landmarks_64[26],
129
+ landmarks_64[23]
130
+ ], axis=0)
131
+ boundaries['inner_lower_lip'] = np.stack([
132
+ landmarks_64[22],
133
+ landmarks_64[27],
134
+ landmarks_64[23]
135
+ ], axis=0)
136
+ functions = {}
137
+
138
+ for key, points in boundaries.items():
139
+ temp = points[0]
140
+ new_points = points[0:1, :]
141
+ for point in points[1:]:
142
+ if point[0] == temp[0] and point[1] == temp[1]:
143
+ continue
144
+ else:
145
+ new_points = np.concatenate((new_points, np.expand_dims(point, 0)), axis=0)
146
+ temp = point
147
+ points = new_points
148
+ if points.shape[0] == 1:
149
+ points = np.concatenate((points, points+0.001), axis=0)
150
+ k = min(4, points.shape[0])
151
+ functions[key] = interpolate.splprep([points[:, 0], points[:, 1]], k=k-1,s=0)
152
+
153
+ boundary_map = np.zeros((64, 64))
154
+
155
+ fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96)
156
+
157
+ ax = fig.add_axes([0, 0, 1, 1])
158
+
159
+ ax.axis('off')
160
+
161
+ ax.imshow(boundary_map, interpolation='nearest', cmap='gray')
162
+ #ax.scatter(landmarks[:, 0], landmarks[:, 1], s=1, marker=',', c='w')
163
+
164
+ for key in functions.keys():
165
+ xnew = np.arange(0, 1, 0.01)
166
+ out = interpolate.splev(xnew, functions[key][0], der=0)
167
+ plt.plot(out[0], out[1], ',', linewidth=1, color='w')
168
+
169
+ img = fig2data(fig)
170
+
171
+ plt.close()
172
+
173
+ sigma = 1
174
+ temp = 255-img[:,:,1]
175
+ temp = cv2.distanceTransform(temp, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
176
+ temp = temp.astype(np.float32)
177
+ temp = np.where(temp < 3*sigma, np.exp(-(temp*temp)/(2*sigma*sigma)), 0 )
178
+
179
+ fig = plt.figure(figsize=[64/96.0, 64/96.0], dpi=96)
180
+
181
+ ax = fig.add_axes([0, 0, 1, 1])
182
+
183
+ ax.axis('off')
184
+ ax.imshow(temp, cmap='gray')
185
+ plt.close()
186
+
187
+ boundary_map = fig2data(fig)
188
+
189
+ sample['boundary'] = boundary_map[:, :, 0]
190
+
191
+ return sample
192
+
193
+ class AddWeightMap(object):
194
+ def __call__(self, sample):
195
+ heatmap= sample['heatmap']
196
+ boundary = sample['boundary']
197
+ heatmap = np.concatenate((heatmap, np.expand_dims(boundary, axis=0)), 0)
198
+ weight_map = np.zeros_like(heatmap)
199
+ for i in range(heatmap.shape[0]):
200
+ weight_map[i] = generate_weight_map(weight_map[i],
201
+ heatmap[i])
202
+ sample['weight_map'] = weight_map
203
+ return sample
204
+
205
+ class ToTensor(object):
206
+ """Convert ndarrays in sample to Tensors."""
207
+
208
+ def __call__(self, sample):
209
+ image, heatmap, landmarks, boundary, weight_map= sample['image'], sample['heatmap'], sample['landmarks'], sample['boundary'], sample['weight_map']
210
+
211
+ # swap color axis because
212
+ # numpy image: H x W x C
213
+ # torch image: C X H X W
214
+ if len(image.shape) == 2:
215
+ image = np.expand_dims(image, axis=2)
216
+ image_small = np.expand_dims(image_small, axis=2)
217
+ image = image.transpose((2, 0, 1))
218
+ boundary = np.expand_dims(boundary, axis=2)
219
+ boundary = boundary.transpose((2, 0, 1))
220
+ return {'image': torch.from_numpy(image).float().div(255.0),
221
+ 'heatmap': torch.from_numpy(heatmap).float(),
222
+ 'landmarks': torch.from_numpy(landmarks).float(),
223
+ 'boundary': torch.from_numpy(boundary).float().div(255.0),
224
+ 'weight_map': torch.from_numpy(weight_map).float()}
225
+
226
+ class FaceLandmarksDataset(Dataset):
227
+ """Face Landmarks dataset."""
228
+
229
+ def __init__(self, img_dir, landmarks_dir, num_landmarks=68, gray_scale=False,
230
+ detect_face=False, enhance=False, center_shift=0,
231
+ transform=None,):
232
+ """
233
+ Args:
234
+ landmark_dir (string): Path to the mat file with landmarks saved.
235
+ img_dir (string): Directory with all the images.
236
+ transform (callable, optional): Optional transform to be applied
237
+ on a sample.
238
+ """
239
+ self.img_dir = img_dir
240
+ self.landmarks_dir = landmarks_dir
241
+ self.num_lanmdkars = num_landmarks
242
+ self.transform = transform
243
+ self.img_names = glob.glob(self.img_dir+'*.jpg') + \
244
+ glob.glob(self.img_dir+'*.png')
245
+ self.gray_scale = gray_scale
246
+ self.detect_face = detect_face
247
+ self.enhance = enhance
248
+ self.center_shift = center_shift
249
+ if self.detect_face:
250
+ self.face_detector = MTCNN(thresh=[0.5, 0.6, 0.7])
251
+ def __len__(self):
252
+ return len(self.img_names)
253
+
254
+ def __getitem__(self, idx):
255
+ img_name = self.img_names[idx]
256
+ pil_image = Image.open(img_name)
257
+ if pil_image.mode != "RGB":
258
+ # if input is grayscale image, convert it to 3 channel image
259
+ if self.enhance:
260
+ pil_image = power_transform(pil_image, 0.5)
261
+ temp_image = Image.new('RGB', pil_image.size)
262
+ temp_image.paste(pil_image)
263
+ pil_image = temp_image
264
+ image = np.array(pil_image)
265
+ if self.gray_scale:
266
+ image = rgb2gray(image)
267
+ image = np.expand_dims(image, axis=2)
268
+ image = np.concatenate((image, image, image), axis=2)
269
+ image = image * 255.0
270
+ image = image.astype(np.uint8)
271
+ if not self.detect_face:
272
+ center = [450//2, 450//2+0]
273
+ if self.center_shift != 0:
274
+ center[0] += int(np.random.uniform(-self.center_shift,
275
+ self.center_shift))
276
+ center[1] += int(np.random.uniform(-self.center_shift,
277
+ self.center_shift))
278
+ scale = 1.8
279
+ else:
280
+ detected_faces = self.face_detector.detect_image(image)
281
+ if len(detected_faces) > 0:
282
+ box = detected_faces[0]
283
+ left, top, right, bottom, _ = box
284
+ center = [right - (right - left) / 2.0,
285
+ bottom - (bottom - top) / 2.0]
286
+ center[1] = center[1] - (bottom - top) * 0.12
287
+ scale = (right - left + bottom - top) / 195.0
288
+ else:
289
+ center = [450//2, 450//2+0]
290
+ scale = 1.8
291
+ if self.center_shift != 0:
292
+ shift = self.center * self.center_shift / 450
293
+ center[0] += int(np.random.uniform(-shift, shift))
294
+ center[1] += int(np.random.uniform(-shift, shift))
295
+ base_name = os.path.basename(img_name)
296
+ landmarks_base_name = base_name[:-4] + '_pts.mat'
297
+ landmarks_name = os.path.join(self.landmarks_dir, landmarks_base_name)
298
+ if os.path.isfile(landmarks_name):
299
+ mat_data = sio.loadmat(landmarks_name)
300
+ landmarks = mat_data['pts_2d']
301
+ elif os.path.isfile(landmarks_name[:-8] + '.pts.npy'):
302
+ landmarks = np.load(landmarks_name[:-8] + '.pts.npy')
303
+ else:
304
+ landmarks = []
305
+ heatmap = []
306
+
307
+ if landmarks != []:
308
+ new_image, new_landmarks = cv_crop(image, landmarks, center,
309
+ scale, 256, self.center_shift)
310
+ tries = 0
311
+ while self.center_shift != 0 and tries < 5 and (np.max(new_landmarks) > 240 or np.min(new_landmarks) < 15):
312
+ center = [450//2, 450//2+0]
313
+ scale += 0.05
314
+ center[0] += int(np.random.uniform(-self.center_shift,
315
+ self.center_shift))
316
+ center[1] += int(np.random.uniform(-self.center_shift,
317
+ self.center_shift))
318
+
319
+ new_image, new_landmarks = cv_crop(image, landmarks,
320
+ center, scale, 256,
321
+ self.center_shift)
322
+ tries += 1
323
+ if np.max(new_landmarks) > 250 or np.min(new_landmarks) < 5:
324
+ center = [450//2, 450//2+0]
325
+ scale = 2.25
326
+ new_image, new_landmarks = cv_crop(image, landmarks,
327
+ center, scale, 256,
328
+ 100)
329
+ assert (np.min(new_landmarks) > 0 and np.max(new_landmarks) < 256), \
330
+ "Landmarks out of boundary!"
331
+ image = new_image
332
+ landmarks = new_landmarks
333
+ heatmap = np.zeros((self.num_lanmdkars, 64, 64))
334
+ for i in range(self.num_lanmdkars):
335
+ if landmarks[i][0] > 0:
336
+ heatmap[i] = draw_gaussian(heatmap[i], landmarks[i]/4.0+1, 1)
337
+ sample = {'image': image, 'heatmap': heatmap, 'landmarks': landmarks}
338
+ if self.transform:
339
+ sample = self.transform(sample)
340
+
341
+ return sample
342
+
343
+ def get_dataset(val_img_dir, val_landmarks_dir, batch_size,
344
+ num_landmarks=68, rotation=0, scale=0,
345
+ center_shift=0, random_flip=False,
346
+ brightness=0, contrast=0, saturation=0,
347
+ blur=False, noise=False, jpeg_effect=False,
348
+ random_occlusion=False, gray_scale=False,
349
+ detect_face=False, enhance=False):
350
+ val_transforms = transforms.Compose([AddBoundary(num_landmarks),
351
+ AddWeightMap(),
352
+ ToTensor()])
353
+
354
+ val_dataset = FaceLandmarksDataset(val_img_dir, val_landmarks_dir,
355
+ num_landmarks=num_landmarks,
356
+ gray_scale=gray_scale,
357
+ detect_face=detect_face,
358
+ enhance=enhance,
359
+ transform=val_transforms)
360
+
361
+ val_dataloader = torch.utils.data.DataLoader(val_dataset,
362
+ batch_size=batch_size,
363
+ shuffle=False,
364
+ num_workers=6)
365
+ data_loaders = {'val': val_dataloader}
366
+ dataset_sizes = {}
367
+ dataset_sizes['val'] = len(val_dataset)
368
+ return data_loaders, dataset_sizes
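A hypothetical call into the `get_dataset` helper added above (the directory paths and batch size are placeholders, not values from this commit); it returns a `{'val': DataLoader}` dict plus dataset sizes, with each batch carrying the image, heatmaps, boundary map, and weight map produced by the transforms defined in this file:

```python
# Sketch of driving the validation dataloader (paths are assumptions and must
# point at images plus matching landmark files for this to actually iterate).
from core.dataloader import get_dataset

data_loaders, dataset_sizes = get_dataset(
    val_img_dir="dataset/WFLW_test/images/",
    val_landmarks_dir="dataset/WFLW_test/landmarks/",
    batch_size=8,
    num_landmarks=98,
)
print(dataset_sizes["val"])
batch = next(iter(data_loaders["val"]))
print(batch["image"].shape, batch["heatmap"].shape)  # e.g. (8, 3, 256, 256) and (8, 98, 64, 64)
```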
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/evaler.py ADDED
@@ -0,0 +1,151 @@
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+ import math
4
+ import torch
5
+ import copy
6
+ import time
7
+ from torch.autograd import Variable
8
+ import shutil
9
+ from skimage import io
10
+ import numpy as np
11
+ from utils.utils import fan_NME, show_landmarks, get_preds_fromhm
12
+ from PIL import Image, ImageDraw
13
+ import os
14
+ import sys
15
+ import cv2
16
+ import matplotlib.pyplot as plt
17
+
18
+
19
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
+
21
+ def eval_model(model, dataloaders, dataset_sizes,
22
+ writer, use_gpu=True, epoches=5, dataset='val',
23
+ save_path='./', num_landmarks=68):
24
+ global_nme = 0
25
+ model.eval()
26
+ for epoch in range(epoches):
27
+ running_loss = 0
28
+ step = 0
29
+ total_nme = 0
30
+ total_count = 0
31
+ fail_count = 0
32
+ nmes = []
33
+ # running_corrects = 0
34
+
35
+ # Iterate over data.
36
+ with torch.no_grad():
37
+ for data in dataloaders[dataset]:
38
+ total_runtime = 0
39
+ run_count = 0
40
+ step_start = time.time()
41
+ step += 1
42
+ # get the inputs
43
+ inputs = data['image'].type(torch.FloatTensor)
44
+ labels_heatmap = data['heatmap'].type(torch.FloatTensor)
45
+ labels_boundary = data['boundary'].type(torch.FloatTensor)
46
+ landmarks = data['landmarks'].type(torch.FloatTensor)
47
+ loss_weight_map = data['weight_map'].type(torch.FloatTensor)
48
+ # wrap them in Variable
49
+ if use_gpu:
50
+ inputs = inputs.to(device)
51
+ labels_heatmap = labels_heatmap.to(device)
52
+ labels_boundary = labels_boundary.to(device)
53
+ loss_weight_map = loss_weight_map.to(device)
54
+ else:
55
+ inputs, labels_heatmap = Variable(inputs), Variable(labels_heatmap)
56
+ labels_boundary = Variable(labels_boundary)
57
+ labels = torch.cat((labels_heatmap, labels_boundary), 1)
58
+ single_start = time.time()
59
+ outputs, boundary_channels = model(inputs)
60
+ single_end = time.time()
61
+ total_runtime += time.time() - single_start
62
+ run_count += 1
63
+ step_end = time.time()
64
+ for i in range(inputs.shape[0]):
65
+ print(inputs.shape)
66
+ img = inputs[i]
67
+ img = img.cpu().numpy()
68
+ img = img.transpose((1, 2, 0)) #*255.0
69
+ # img = img.astype(np.uint8)
70
+ # img = Image.fromarray(img)
71
+ # pred_heatmap = outputs[-1][i].detach().cpu()[:-1, :, :]
72
+ pred_heatmap = outputs[-1][:, :-1, :, :][i].detach().cpu()
73
+ pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
74
+ pred_landmarks = pred_landmarks.squeeze().numpy()
75
+
76
+ gt_landmarks = data['landmarks'][i].numpy()
77
+ print(pred_landmarks, gt_landmarks)
78
+ import cv2
79
+ while(True):
80
+ imgshow = vis_landmark_on_img(cv2.UMat(img), pred_landmarks*4)
81
+ cv2.imshow('img', imgshow)
82
+
83
+ if(cv2.waitKey(10) == ord('q')):
84
+ break
85
+
86
+
87
+ if num_landmarks == 68:
88
+ left_eye = np.average(gt_landmarks[36:42], axis=0)
89
+ right_eye = np.average(gt_landmarks[42:48], axis=0)
90
+ norm_factor = np.linalg.norm(left_eye - right_eye)
91
+ # norm_factor = np.linalg.norm(gt_landmarks[36]- gt_landmarks[45])
92
+
93
+ elif num_landmarks == 98:
94
+ norm_factor = np.linalg.norm(gt_landmarks[60]- gt_landmarks[72])
95
+ elif num_landmarks == 19:
96
+ left, top = gt_landmarks[-2, :]
97
+ right, bottom = gt_landmarks[-1, :]
98
+ norm_factor = math.sqrt(abs(right - left)*abs(top-bottom))
99
+ gt_landmarks = gt_landmarks[:-2, :]
100
+ elif num_landmarks == 29:
101
+ # norm_factor = np.linalg.norm(gt_landmarks[8]- gt_landmarks[9])
102
+ norm_factor = np.linalg.norm(gt_landmarks[16]- gt_landmarks[17])
103
+ single_nme = (np.sum(np.linalg.norm(pred_landmarks*4 - gt_landmarks, axis=1)) / pred_landmarks.shape[0]) / norm_factor
104
+
105
+ nmes.append(single_nme)
106
+ total_count += 1
107
+ if single_nme > 0.1:
108
+ fail_count += 1
109
+ if step % 10 == 0:
110
+ print('Step {} Time: {:.6f} Input Mean: {:.6f} Output Mean: {:.6f}'.format(
111
+ step, step_end - step_start,
112
+ torch.mean(labels),
113
+ torch.mean(outputs[0])))
114
+ # gt_landmarks = landmarks.numpy()
115
+ # pred_heatmap = outputs[-1].to('cpu').numpy()
116
+ gt_landmarks = landmarks
117
+ batch_nme = fan_NME(outputs[-1][:, :-1, :, :].detach().cpu(), gt_landmarks, num_landmarks)
118
+ # batch_nme = 0
119
+ total_nme += batch_nme
120
+ epoch_nme = total_nme / dataset_sizes['val']
121
+ global_nme += epoch_nme
122
+ nme_save_path = os.path.join(save_path, 'nme_log.npy')
123
+ np.save(nme_save_path, np.array(nmes))
124
+ print('NME: {:.6f} Failure Rate: {:.6f} Total Count: {:.6f} Fail Count: {:.6f}'.format(epoch_nme, fail_count/total_count, total_count, fail_count))
125
+ print('Evaluation done! Average NME: {:.6f}'.format(global_nme/epoches))
126
+ print('Everage runtime for a single batch: {:.6f}'.format(total_runtime/run_count))
127
+ return model
128
+
129
+
130
+ def vis_landmark_on_img(img, shape, linewidth=2):
131
+ '''
132
+ Visualize landmark on images.
133
+ '''
134
+
135
+ def draw_curve(idx_list, color=(0, 255, 0), loop=False, lineWidth=linewidth):
136
+ for i in idx_list:
137
+ cv2.line(img, (shape[i, 0], shape[i, 1]), (shape[i + 1, 0], shape[i + 1, 1]), color, lineWidth)
138
+ if (loop):
139
+ cv2.line(img, (shape[idx_list[0], 0], shape[idx_list[0], 1]),
140
+ (shape[idx_list[-1] + 1, 0], shape[idx_list[-1] + 1, 1]), color, lineWidth)
141
+
142
+ draw_curve(list(range(0, 32))) # jaw
143
+ draw_curve(list(range(33, 41)), color=(0, 0, 255), loop=True) # eye brow
144
+ draw_curve(list(range(42, 50)), color=(0, 0, 255), loop=True)
145
+ draw_curve(list(range(51, 59))) # nose
146
+ draw_curve(list(range(60, 67)), loop=True) # eyes
147
+ draw_curve(list(range(68, 75)), loop=True)
148
+ draw_curve(list(range(76, 87)), loop=True, color=(0, 255, 255)) # mouth
149
+ draw_curve(list(range(88, 95)), loop=True, color=(255, 255, 0))
150
+
151
+ return img
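The NME computed in `eval_model` above normalizes the mean per-landmark error by an inter-ocular distance (or, for the 19- and 29-point sets, a box or inter-pupil factor). A toy, self-contained illustration of the same 98-point formula, using synthetic numbers purely for clarity:

```python
# Toy NME illustration mirroring the 98-landmark branch above (synthetic data).
import numpy as np

gt = np.random.rand(98, 2) * 256                 # ground-truth landmarks in pixels
pred = gt + np.random.randn(98, 2)               # predictions, perturbed for the example
norm_factor = np.linalg.norm(gt[60] - gt[72])    # outer eye corners, as in eval_model
nme = np.sum(np.linalg.norm(pred - gt, axis=1)) / pred.shape[0] / norm_factor
print(f"NME: {nme:.4f}")
```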
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/core/models.py ADDED
@@ -0,0 +1,228 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from core.coord_conv import CoordConvTh
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, strd=1, padding=1,
9
+ bias=False,dilation=1):
10
+ "3x3 convolution with padding"
11
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
12
+ stride=strd, padding=padding, bias=bias,
13
+ dilation=dilation)
14
+
15
+ class BasicBlock(nn.Module):
16
+ expansion = 1
17
+
18
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
19
+ super(BasicBlock, self).__init__()
20
+ self.conv1 = conv3x3(inplanes, planes, stride)
21
+ # self.bn1 = nn.BatchNorm2d(planes)
22
+ self.relu = nn.ReLU(inplace=True)
23
+ self.conv2 = conv3x3(planes, planes)
24
+ # self.bn2 = nn.BatchNorm2d(planes)
25
+ self.downsample = downsample
26
+ self.stride = stride
27
+
28
+ def forward(self, x):
29
+ residual = x
30
+
31
+ out = self.conv1(x)
32
+ # out = self.bn1(out)
33
+ out = self.relu(out)
34
+
35
+ out = self.conv2(out)
36
+ # out = self.bn2(out)
37
+
38
+ if self.downsample is not None:
39
+ residual = self.downsample(x)
40
+
41
+ out += residual
42
+ out = self.relu(out)
43
+
44
+ return out
45
+
46
+ class ConvBlock(nn.Module):
47
+ def __init__(self, in_planes, out_planes):
48
+ super(ConvBlock, self).__init__()
49
+ self.bn1 = nn.BatchNorm2d(in_planes)
50
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
51
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
52
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4),
53
+ padding=1, dilation=1)
54
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
55
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4),
56
+ padding=1, dilation=1)
57
+
58
+ if in_planes != out_planes:
59
+ self.downsample = nn.Sequential(
60
+ nn.BatchNorm2d(in_planes),
61
+ nn.ReLU(True),
62
+ nn.Conv2d(in_planes, out_planes,
63
+ kernel_size=1, stride=1, bias=False),
64
+ )
65
+ else:
66
+ self.downsample = None
67
+
68
+ def forward(self, x):
69
+ residual = x
70
+
71
+ out1 = self.bn1(x)
72
+ out1 = F.relu(out1, True)
73
+ out1 = self.conv1(out1)
74
+
75
+ out2 = self.bn2(out1)
76
+ out2 = F.relu(out2, True)
77
+ out2 = self.conv2(out2)
78
+
79
+ out3 = self.bn3(out2)
80
+ out3 = F.relu(out3, True)
81
+ out3 = self.conv3(out3)
82
+
83
+ out3 = torch.cat((out1, out2, out3), 1)
84
+
85
+ if self.downsample is not None:
86
+ residual = self.downsample(residual)
87
+
88
+ out3 += residual
89
+
90
+ return out3
91
+
92
+ class HourGlass(nn.Module):
93
+ def __init__(self, num_modules, depth, num_features, first_one=False):
94
+ super(HourGlass, self).__init__()
95
+ self.num_modules = num_modules
96
+ self.depth = depth
97
+ self.features = num_features
98
+ self.coordconv = CoordConvTh(x_dim=64, y_dim=64,
99
+ with_r=True, with_boundary=True,
100
+ in_channels=256, first_one=first_one,
101
+ out_channels=256,
102
+ kernel_size=1,
103
+ stride=1, padding=0)
104
+ self._generate_network(self.depth)
105
+
106
+ def _generate_network(self, level):
107
+ self.add_module('b1_' + str(level), ConvBlock(256, 256))
108
+
109
+ self.add_module('b2_' + str(level), ConvBlock(256, 256))
110
+
111
+ if level > 1:
112
+ self._generate_network(level - 1)
113
+ else:
114
+ self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
115
+
116
+ self.add_module('b3_' + str(level), ConvBlock(256, 256))
117
+
118
+ def _forward(self, level, inp):
119
+ # Upper branch
120
+ up1 = inp
121
+ up1 = self._modules['b1_' + str(level)](up1)
122
+
123
+ # Lower branch
124
+ low1 = F.avg_pool2d(inp, 2, stride=2)
125
+ low1 = self._modules['b2_' + str(level)](low1)
126
+
127
+ if level > 1:
128
+ low2 = self._forward(level - 1, low1)
129
+ else:
130
+ low2 = low1
131
+ low2 = self._modules['b2_plus_' + str(level)](low2)
132
+
133
+ low3 = low2
134
+ low3 = self._modules['b3_' + str(level)](low3)
135
+
136
+ up2 = F.upsample(low3, scale_factor=2, mode='nearest')
137
+
138
+ return up1 + up2
139
+
140
+ def forward(self, x, heatmap):
141
+ x, last_channel = self.coordconv(x, heatmap)
142
+ return self._forward(self.depth, x), last_channel
143
+
144
+ class FAN(nn.Module):
145
+
146
+ def __init__(self, num_modules=1, end_relu=False, gray_scale=False,
147
+ num_landmarks=68):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+ self.gray_scale = gray_scale
151
+ self.end_relu = end_relu
152
+ self.num_landmarks = num_landmarks
153
+
154
+ # Base part
155
+ if self.gray_scale:
156
+ self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
157
+ with_r=True, with_boundary=False,
158
+ in_channels=3, out_channels=64,
159
+ kernel_size=7,
160
+ stride=2, padding=3)
161
+ else:
162
+ self.conv1 = CoordConvTh(x_dim=256, y_dim=256,
163
+ with_r=True, with_boundary=False,
164
+ in_channels=3, out_channels=64,
165
+ kernel_size=7,
166
+ stride=2, padding=3)
167
+ self.bn1 = nn.BatchNorm2d(64)
168
+ self.conv2 = ConvBlock(64, 128)
169
+ self.conv3 = ConvBlock(128, 128)
170
+ self.conv4 = ConvBlock(128, 256)
171
+
172
+ # Stacking part
173
+ for hg_module in range(self.num_modules):
174
+ if hg_module == 0:
175
+ first_one = True
176
+ else:
177
+ first_one = False
178
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256,
179
+ first_one))
180
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
181
+ self.add_module('conv_last' + str(hg_module),
182
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
183
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
184
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
185
+ num_landmarks+1, kernel_size=1, stride=1, padding=0))
186
+
187
+ if hg_module < self.num_modules - 1:
188
+ self.add_module(
189
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
190
+ self.add_module('al' + str(hg_module), nn.Conv2d(num_landmarks+1,
191
+ 256, kernel_size=1, stride=1, padding=0))
192
+
193
+ def forward(self, x):
194
+ x, _ = self.conv1(x)
195
+ x = F.relu(self.bn1(x), True)
196
+ # x = F.relu(self.bn1(self.conv1(x)), True)
197
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
198
+ x = self.conv3(x)
199
+ x = self.conv4(x)
200
+
201
+ previous = x
202
+
203
+ outputs = []
204
+ boundary_channels = []
205
+ tmp_out = None
206
+ for i in range(self.num_modules):
207
+ hg, boundary_channel = self._modules['m' + str(i)](previous,
208
+ tmp_out)
209
+
210
+ ll = hg
211
+ ll = self._modules['top_m_' + str(i)](ll)
212
+
213
+ ll = F.relu(self._modules['bn_end' + str(i)]
214
+ (self._modules['conv_last' + str(i)](ll)), True)
215
+
216
+ # Predict heatmaps
217
+ tmp_out = self._modules['l' + str(i)](ll)
218
+ if self.end_relu:
219
+ tmp_out = F.relu(tmp_out) # HACK: Added relu
220
+ outputs.append(tmp_out)
221
+ boundary_channels.append(boundary_channel)
222
+
223
+ if i < self.num_modules - 1:
224
+ ll = self._modules['bl' + str(i)](ll)
225
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
226
+ previous = previous + ll + tmp_out_
227
+
228
+ return outputs, boundary_channels
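A small shape sketch for the stacked `FAN` defined above (the module count and landmark count here are illustrative, not the commit's defaults): each hourglass module emits a `num_landmarks + 1` channel heatmap at 64×64, and the boundary tensor returned alongside it by `CoordConvTh` has two channels.

```python
# Hypothetical shape check for FAN (values chosen only for illustration).
import torch
from core.models import FAN

net = FAN(num_modules=2, num_landmarks=68)
x = torch.rand(1, 3, 256, 256)
outputs, boundary_channels = net(x)
print(len(outputs))                  # 2: one heatmap set per hourglass module
print(outputs[-1].shape)             # torch.Size([1, 69, 64, 64]) = 68 landmarks + 1 boundary map
print(boundary_channels[-1].shape)   # torch.Size([1, 2, 64, 64])
```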
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/eval.py ADDED
@@ -0,0 +1,77 @@
+from __future__ import print_function, division
+import torch
+import argparse
+import numpy as np
+import torch.nn as nn
+import time
+import os
+from core.evaler import eval_model
+from core.dataloader import get_dataset
+from core import models
+from tensorboardX import SummaryWriter
+
+# Parse arguments
+parser = argparse.ArgumentParser()
+# Dataset paths
+parser.add_argument('--val_img_dir', type=str,
+                    help='Validation image directory')
+parser.add_argument('--val_landmarks_dir', type=str,
+                    help='Validation landmarks directory')
+parser.add_argument('--num_landmarks', type=int, default=68,
+                    help='Number of landmarks')
+
+# Checkpoint and pretrained weights
+parser.add_argument('--ckpt_save_path', type=str,
+                    help='a directory to save checkpoint file')
+parser.add_argument('--pretrained_weights', type=str,
+                    help='a directory to save pretrained_weights')
+
+# Eval options
+parser.add_argument('--batch_size', type=int, default=25,
+                    help='Batch size used during evaluation')
+
+# Network parameters
+parser.add_argument('--hg_blocks', type=int, default=4,
+                    help='Number of HG blocks to stack')
+parser.add_argument('--gray_scale', type=str, default="False",
+                    help='Whether to convert RGB image into gray scale during training')
+parser.add_argument('--end_relu', type=str, default="False",
+                    help='Whether to add relu at the end of each HG module')
+
+args = parser.parse_args()
+
+VAL_IMG_DIR = args.val_img_dir
+VAL_LANDMARKS_DIR = args.val_landmarks_dir
+CKPT_SAVE_PATH = args.ckpt_save_path
+BATCH_SIZE = args.batch_size
+PRETRAINED_WEIGHTS = args.pretrained_weights
+GRAY_SCALE = False if args.gray_scale == 'False' else True
+HG_BLOCKS = args.hg_blocks
+END_RELU = False if args.end_relu == 'False' else True
+NUM_LANDMARKS = args.num_landmarks
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+writer = SummaryWriter(CKPT_SAVE_PATH)
+
+dataloaders, dataset_sizes = get_dataset(VAL_IMG_DIR, VAL_LANDMARKS_DIR,
+                                         BATCH_SIZE, NUM_LANDMARKS)
+use_gpu = torch.cuda.is_available()
+model_ft = models.FAN(HG_BLOCKS, END_RELU, GRAY_SCALE, NUM_LANDMARKS)
+
+if PRETRAINED_WEIGHTS != "None":
+    checkpoint = torch.load(PRETRAINED_WEIGHTS)
+    if 'state_dict' not in checkpoint:
+        model_ft.load_state_dict(checkpoint)
+    else:
+        pretrained_weights = checkpoint['state_dict']
+        model_weights = model_ft.state_dict()
+        pretrained_weights = {k: v for k, v in pretrained_weights.items()
+                              if k in model_weights}
+        model_weights.update(pretrained_weights)
+        model_ft.load_state_dict(model_weights)
+
+model_ft = model_ft.to(device)
+
+model_ft = eval_model(model_ft, dataloaders, dataset_sizes, writer, use_gpu, 1, 'val', CKPT_SAVE_PATH, NUM_LANDMARKS)
+
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw.png ADDED

Git LFS Details

  • SHA256: 354babe46beeec86fc8a9f64c57a1dad0ec19ff23f455ac3405321bab473ce23
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/images/wflw_table.png ADDED

Git LFS Details

  • SHA256: 87c9ea0af4854681b6fc5e911ac38042ca5099098146501f20b64a6457a9d98b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/requirements.txt ADDED
@@ -0,0 +1,12 @@
+ opencv-python
+ scipy>=0.17.0
+ scikit-image
+ numpy
+ matplotlib
+ Pillow>=4.3.0
+ imgaug
+ tensorflow
+ git+https://github.com/lanpa/tensorboardX
+ joblib
+ torch==1.3.0
+ torchvision==0.4.1
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/scripts/eval_wflw.sh ADDED
@@ -0,0 +1,10 @@
+ CUDA_VISIBLE_DEVICES=1 python ../eval.py \
+     --val_img_dir='../dataset/WFLW_test/images/' \
+     --val_landmarks_dir='../dataset/WFLW_test/landmarks/' \
+     --ckpt_save_path='../experiments/eval_iccv_0620' \
+     --hg_blocks=4 \
+     --pretrained_weights='../ckpt/WFLW_4HG.pth' \
+     --num_landmarks=98 \
+     --end_relu='False' \
+     --batch_size=20 \
+
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__init__.py ADDED
File without changes
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (170 Bytes). View file
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (185 Bytes). View file
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-37.pyc ADDED
Binary file (11.8 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (11.6 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/AdaptiveWingLoss/utils/utils.py ADDED
@@ -0,0 +1,354 @@
1
+ from __future__ import print_function, division
2
+ import os
3
+ import sys
4
+ import math
5
+ import torch
6
+ import cv2
7
+ from PIL import Image
8
+ from skimage import io
9
+ from skimage import transform as ski_transform
10
+ from scipy import ndimage
11
+ import numpy as np
12
+ import matplotlib
13
+ import matplotlib.pyplot as plt
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from torchvision import transforms, utils
+ from functools import reduce  # used in transform() below; missing from the original imports
16
+
17
+ def _gaussian(
18
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
19
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
20
+ mean_vert=0.5):
21
+ # handle some defaults
22
+ if width is None:
23
+ width = size
24
+ if height is None:
25
+ height = size
26
+ if sigma_horz is None:
27
+ sigma_horz = sigma
28
+ if sigma_vert is None:
29
+ sigma_vert = sigma
30
+ center_x = mean_horz * width + 0.5
31
+ center_y = mean_vert * height + 0.5
32
+ gauss = np.empty((height, width), dtype=np.float32)
33
+ # generate kernel
34
+ for i in range(height):
35
+ for j in range(width):
36
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
37
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
38
+ if normalize:
39
+ gauss = gauss / np.sum(gauss)
40
+ return gauss
41
+
42
+ def draw_gaussian(image, point, sigma):
43
+ # Check if the gaussian is inside
44
+ ul = [np.floor(np.floor(point[0]) - 3 * sigma),
45
+ np.floor(np.floor(point[1]) - 3 * sigma)]
46
+ br = [np.floor(np.floor(point[0]) + 3 * sigma),
47
+ np.floor(np.floor(point[1]) + 3 * sigma)]
48
+ if (ul[0] > image.shape[1] or ul[1] >
49
+ image.shape[0] or br[0] < 1 or br[1] < 1):
50
+ return image
51
+ size = 6 * sigma + 1
52
+ g = _gaussian(size)
53
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
54
+ int(max(1, ul[0])) + int(max(1, -ul[0]))]
55
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
56
+ int(max(1, ul[1])) + int(max(1, -ul[1]))]
57
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
58
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
59
+ assert (g_x[0] > 0 and g_y[1] > 0)
60
+ correct = False
61
+ while not correct:
62
+ try:
63
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
64
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
65
+ correct = True
66
+ except:
67
+ print('img_x: {}, img_y: {}, g_x:{}, g_y:{}, point:{}, g_shape:{}, ul:{}, br:{}'.format(img_x, img_y, g_x, g_y, point, g.shape, ul, br))
68
+ ul = [np.floor(np.floor(point[0]) - 3 * sigma),
69
+ np.floor(np.floor(point[1]) - 3 * sigma)]
70
+ br = [np.floor(np.floor(point[0]) + 3 * sigma),
71
+ np.floor(np.floor(point[1]) + 3 * sigma)]
72
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) -
73
+ int(max(1, ul[0])) + int(max(1, -ul[0]))]
74
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) -
75
+ int(max(1, ul[1])) + int(max(1, -ul[1]))]
76
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
77
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
78
+ pass
79
+ image[image > 1] = 1
80
+ return image
81
+
82
+ def transform(point, center, scale, resolution, rotation=0, invert=False):
83
+ _pt = np.ones(3)
84
+ _pt[0] = point[0]
85
+ _pt[1] = point[1]
86
+
87
+ h = 200.0 * scale
88
+ t = np.eye(3)
89
+ t[0, 0] = resolution / h
90
+ t[1, 1] = resolution / h
91
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
92
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
93
+
94
+ if rotation != 0:
95
+ rotation = -rotation
96
+ r = np.eye(3)
97
+ ang = rotation * math.pi / 180.0
98
+ s = math.sin(ang)
99
+ c = math.cos(ang)
100
+ r[0][0] = c
101
+ r[0][1] = -s
102
+ r[1][0] = s
103
+ r[1][1] = c
104
+
105
+ t_ = np.eye(3)
106
+ t_[0][2] = -resolution / 2.0
107
+ t_[1][2] = -resolution / 2.0
108
+ t_inv = torch.eye(3)
109
+ t_inv[0][2] = resolution / 2.0
110
+ t_inv[1][2] = resolution / 2.0
111
+ t = reduce(np.matmul, [t_inv, r, t_, t])
112
+
113
+ if invert:
114
+ t = np.linalg.inv(t)
115
+ new_point = (np.matmul(t, _pt))[0:2]
116
+
117
+ return new_point.astype(int)
118
+
119
+ def cv_crop(image, landmarks, center, scale, resolution=256, center_shift=0):
120
+ new_image = cv2.copyMakeBorder(image, center_shift,
121
+ center_shift,
122
+ center_shift,
123
+ center_shift,
124
+ cv2.BORDER_CONSTANT, value=[0,0,0])
125
+ new_landmarks = landmarks.copy()
126
+ if center_shift != 0:
127
+ center[0] += center_shift
128
+ center[1] += center_shift
129
+ new_landmarks = new_landmarks + center_shift
130
+ length = 200 * scale
131
+ top = int(center[1] - length // 2)
132
+ bottom = int(center[1] + length // 2)
133
+ left = int(center[0] - length // 2)
134
+ right = int(center[0] + length // 2)
135
+ y_pad = abs(min(top, new_image.shape[0] - bottom, 0))
136
+ x_pad = abs(min(left, new_image.shape[1] - right, 0))
137
+ top, bottom, left, right = top + y_pad, bottom + y_pad, left + x_pad, right + x_pad
138
+ new_image = cv2.copyMakeBorder(new_image, y_pad,
139
+ y_pad,
140
+ x_pad,
141
+ x_pad,
142
+ cv2.BORDER_CONSTANT, value=[0,0,0])
143
+ new_image = new_image[top:bottom, left:right]
144
+ new_image = cv2.resize(new_image, dsize=(int(resolution), int(resolution)),
145
+ interpolation=cv2.INTER_LINEAR)
146
+ new_landmarks[:, 0] = (new_landmarks[:, 0] + x_pad - left) * resolution / length
147
+ new_landmarks[:, 1] = (new_landmarks[:, 1] + y_pad - top) * resolution / length
148
+ return new_image, new_landmarks
149
+
150
+ def cv_rotate(image, landmarks, heatmap, rot, scale, resolution=256):
151
+ img_mat = cv2.getRotationMatrix2D((resolution//2, resolution//2), rot, scale)
152
+ ones = np.ones(shape=(landmarks.shape[0], 1))
153
+ stacked_landmarks = np.hstack([landmarks, ones])
154
+ new_landmarks = img_mat.dot(stacked_landmarks.T).T
155
+ if np.max(new_landmarks) > 255 or np.min(new_landmarks) < 0:
156
+ return image, landmarks, heatmap
157
+ else:
158
+ new_image = cv2.warpAffine(image, img_mat, (resolution, resolution))
159
+ if heatmap is not None:
160
+ new_heatmap = np.zeros((heatmap.shape[0], 64, 64))
161
+ for i in range(heatmap.shape[0]):
162
+ if new_landmarks[i][0] > 0:
163
+ new_heatmap[i] = draw_gaussian(new_heatmap[i],
164
+ new_landmarks[i]/4.0+1, 1)
165
+ return new_image, new_landmarks, new_heatmap
166
+
167
+ def show_landmarks(image, heatmap, gt_landmarks, gt_heatmap):
168
+ """Show image with pred_landmarks"""
169
+ pred_landmarks = []
170
+ pred_landmarks, _ = get_preds_fromhm(torch.from_numpy(heatmap).unsqueeze(0))
171
+ pred_landmarks = pred_landmarks.squeeze()*4
172
+
173
+ # pred_landmarks2 = get_preds_fromhm2(heatmap)
174
+ heatmap = np.max(gt_heatmap, axis=0)
175
+ heatmap = heatmap / np.max(heatmap)
176
+ # image = ski_transform.resize(image, (64, 64))*255
177
+ image = image.astype(np.uint8)
178
+ heatmap = np.max(gt_heatmap, axis=0)
179
+ heatmap = ski_transform.resize(heatmap, (image.shape[0], image.shape[1]))
180
+ heatmap *= 255
181
+ heatmap = heatmap.astype(np.uint8)
182
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
183
+ plt.imshow(image)
184
+ plt.scatter(gt_landmarks[:, 0], gt_landmarks[:, 1], s=0.5, marker='.', c='g')
185
+ plt.scatter(pred_landmarks[:, 0], pred_landmarks[:, 1], s=0.5, marker='.', c='r')
186
+ plt.pause(0.001) # pause a bit so that plots are updated
187
+
188
+ def fan_NME(pred_heatmaps, gt_landmarks, num_landmarks=68):
189
+ '''
190
+ Calculate total NME for a batch of data
191
+
192
+ Args:
193
+ pred_heatmaps: torch tensor of size [batch, points, height, width]
194
+ gt_landmarks: torch tensor of size [batch, points, x, y]
195
+
196
+ Returns:
197
+ nme: sum of nme for this batch
198
+ '''
199
+ nme = 0
200
+ pred_landmarks, _ = get_preds_fromhm(pred_heatmaps)
201
+ pred_landmarks = pred_landmarks.numpy()
202
+ gt_landmarks = gt_landmarks.numpy()
203
+ for i in range(pred_landmarks.shape[0]):
204
+ pred_landmark = pred_landmarks[i] * 4.0
205
+ gt_landmark = gt_landmarks[i]
206
+
207
+ if num_landmarks == 68:
208
+ left_eye = np.average(gt_landmark[36:42], axis=0)
209
+ right_eye = np.average(gt_landmark[42:48], axis=0)
210
+ norm_factor = np.linalg.norm(left_eye - right_eye)
211
+ # norm_factor = np.linalg.norm(gt_landmark[36]- gt_landmark[45])
212
+ elif num_landmarks == 98:
213
+ norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72])
214
+ elif num_landmarks == 19:
215
+ left, top = gt_landmark[-2, :]
216
+ right, bottom = gt_landmark[-1, :]
217
+ norm_factor = math.sqrt(abs(right - left)*abs(top-bottom))
218
+ gt_landmark = gt_landmark[:-2, :]
219
+ elif num_landmarks == 29:
220
+ # norm_factor = np.linalg.norm(gt_landmark[8]- gt_landmark[9])
221
+ norm_factor = np.linalg.norm(gt_landmark[16]- gt_landmark[17])
222
+ nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
223
+ return nme
224
+
225
+ def fan_NME_hm(pred_heatmaps, gt_heatmaps, num_landmarks=68):
226
+ '''
227
+ Calculate total NME for a batch of data
228
+
229
+ Args:
230
+ pred_heatmaps: torch tensor of size [batch, points, height, width]
231
+ gt_landmarks: torch tensor of size [batch, points, x, y]
232
+
233
+ Returns:
234
+ nme: sum of nme for this batch
235
+ '''
236
+ nme = 0
237
+ pred_landmarks, _ = get_index_fromhm(pred_heatmaps)
238
+ pred_landmarks = pred_landmarks.numpy()
239
+ gt_landmarks = gt_landmarks.numpy()
240
+ for i in range(pred_landmarks.shape[0]):
241
+ pred_landmark = pred_landmarks[i] * 4.0
242
+ gt_landmark = gt_landmarks[i]
243
+ if num_landmarks == 68:
244
+ left_eye = np.average(gt_landmark[36:42], axis=0)
245
+ right_eye = np.average(gt_landmark[42:48], axis=0)
246
+ norm_factor = np.linalg.norm(left_eye - right_eye)
247
+ else:
248
+ norm_factor = np.linalg.norm(gt_landmark[60]- gt_landmark[72])
249
+ nme += (np.sum(np.linalg.norm(pred_landmark - gt_landmark, axis=1)) / pred_landmark.shape[0]) / norm_factor
250
+ return nme
251
+
252
+ def power_transform(img, power):
253
+ img = np.array(img)
254
+ img_new = np.power((img/255.0), power) * 255.0
255
+ img_new = img_new.astype(np.uint8)
256
+ img_new = Image.fromarray(img_new)
257
+ return img_new
258
+
259
+ def get_preds_fromhm(hm, center=None, scale=None, rot=None):
260
+ max, idx = torch.max(
261
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
262
+ idx += 1
263
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
264
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
265
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
266
+
267
+ for i in range(preds.size(0)):
268
+ for j in range(preds.size(1)):
269
+ hm_ = hm[i, j, :]
270
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
271
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
272
+ diff = torch.FloatTensor(
273
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
274
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
275
+ preds[i, j].add_(diff.sign_().mul_(.25))
276
+
277
+ preds.add_(-0.5)
278
+
279
+ preds_orig = torch.zeros(preds.size())
280
+ if center is not None and scale is not None:
281
+ for i in range(hm.size(0)):
282
+ for j in range(hm.size(1)):
283
+ preds_orig[i, j] = transform(
284
+ preds[i, j], center, scale, hm.size(2), rot, True)
285
+
286
+ return preds, preds_orig
287
+
288
+ def get_index_fromhm(hm):
289
+ max, idx = torch.max(
290
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
291
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
292
+ preds[..., 0].remainder_(hm.size(3))
293
+ preds[..., 1].div_(hm.size(2)).floor_()
294
+
295
+ for i in range(preds.size(0)):
296
+ for j in range(preds.size(1)):
297
+ hm_ = hm[i, j, :]
298
+ pX, pY = int(preds[i, j, 0]), int(preds[i, j, 1])
299
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
300
+ diff = torch.FloatTensor(
301
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
302
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
303
+ preds[i, j].add_(diff.sign_().mul_(.25))
304
+
305
+ return preds
306
+
307
+ def shuffle_lr(parts, num_landmarks=68, pairs=None):
308
+ if num_landmarks == 68:
309
+ if pairs is None:
310
+ pairs = [[0, 16], [1, 15], [2, 14], [3, 13], [4, 12], [5, 11], [6, 10],
311
+ [7, 9], [17, 26], [18, 25], [19, 24], [20, 23], [21, 22], [36, 45],
312
+ [37, 44], [38, 43], [39, 42], [41, 46], [40, 47], [31, 35], [32, 34],
313
+ [50, 52], [49, 53], [48, 54], [61, 63], [60, 64], [67, 65], [59, 55], [58, 56]]
314
+ elif num_landmarks == 98:
315
+ if pairs is None:
316
+ pairs = [[0, 32], [1,31], [2, 30], [3, 29], [4, 28], [5, 27], [6, 26], [7, 25], [8, 24], [9, 23], [10, 22], [11, 21], [12, 20], [13, 19], [14, 18], [15, 17], [33, 46], [34, 45], [35, 44], [36, 43], [37, 42], [38, 50], [39, 49], [40, 48], [41, 47], [60, 72], [61, 71], [62, 70], [63, 69], [64, 68], [65, 75], [66, 74], [67, 73], [96, 97], [55, 59], [56, 58], [76, 82], [77, 81], [78, 80], [88, 92], [89, 91], [95, 93], [87, 83], [86, 84]]
317
+ elif num_landmarks == 19:
318
+ if pairs is None:
319
+ pairs = [[0, 5], [1, 4], [2, 3], [6, 11], [7, 10], [8, 9], [12, 14], [15, 17]]
320
+ elif num_landmarks == 29:
321
+ if pairs is None:
322
+ pairs = [[0, 1], [4, 6], [5, 7], [2, 3], [8, 9], [12, 14], [16, 17], [13, 15], [10, 11], [18, 19], [22, 23]]
323
+ for matched_p in pairs:
324
+ idx1, idx2 = matched_p[0], matched_p[1]
325
+ tmp = np.copy(parts[idx1])
326
+ np.copyto(parts[idx1], parts[idx2])
327
+ np.copyto(parts[idx2], tmp)
328
+ return parts
329
+
330
+
331
+ def generate_weight_map(weight_map,heatmap):
332
+
333
+ k_size = 3
334
+ dilate = ndimage.grey_dilation(heatmap ,size=(k_size,k_size))
335
+ weight_map[np.where(dilate>0.2)] = 1
336
+ return weight_map
337
+
338
+ def fig2data(fig):
339
+ """
340
+ @brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
341
+ @param fig a matplotlib figure
342
+ @return a numpy 3D array of RGBA values
343
+ """
344
+ # draw the renderer
345
+ fig.canvas.draw ( )
346
+
347
+ # Get the RGB buffer from the figure
348
+ w,h = fig.canvas.get_width_height()
349
+ buf = np.fromstring (fig.canvas.tostring_rgb(), dtype=np.uint8)
350
+ buf.shape = (w, h, 3)
351
+
352
+ # canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
353
+ buf = np.roll (buf, 3, axis=2)
354
+ return buf
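As an illustration (not part of the repository), the helpers above can be chained to decode sub-pixel landmark coordinates from heatmaps and score them; the heatmaps and ground-truth landmarks here are random stand-ins, and the 4x factor maps 64x64 heatmap coordinates back to the 256x256 input crop.

import torch

heatmaps = torch.rand(1, 68, 64, 64)                 # stand-in for the last FAN output stack
preds, _ = get_preds_fromhm(heatmaps)                # (1, 68, 2), in 64x64 heatmap coordinates
landmarks = preds * 4.0                              # back to 256x256 input-image coordinates

gt = torch.rand(1, 68, 2) * 255                      # fake ground-truth landmarks in image coordinates
nme_sum = fan_NME(heatmaps, gt, num_landmarks=68)    # sum of per-image NME over the batch
print(landmarks.shape, nme_sum)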
marlenezw/audio-driven-animations/MakeItTalk/__init__.py ADDED
File without changes
marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (147 Bytes). View file
 
marlenezw/audio-driven-animations/MakeItTalk/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (162 Bytes). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/CODEOWNERS ADDED
@@ -0,0 +1 @@
+ * @papulke
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/LICENCE.txt ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2019 Jordan Yaniv
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+ OR OTHER DEALINGS IN THE SOFTWARE.
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/README.md ADDED
@@ -0,0 +1,98 @@
+ # The Face of Art: Landmark Detection and Geometric Style in Portraits
+
+ Code for the landmark detection framework described in [The Face of Art: Landmark Detection and Geometric Style in Portraits](http://www.faculty.idc.ac.il/arik/site/foa/face-of-art.asp) (SIGGRAPH 2019)
+
+ ![](old/teaser.png)
+ <sub><sup>Top: landmark detection results on artistic portraits with different styles allow us to define the geometric style of an artist. Bottom: results of the style transfer of portraits using various artists' geometric style, including Amedeo Modigliani, Pablo Picasso, Margaret Keane, Fernand Léger, and Tsuguharu Foujita. Top right portrait is from 'Woman with Peanuts,' ©1962, Estate of Roy Lichtenstein.</sup></sub>
+
+ ## Getting Started
+
+ ### Requirements
+
+ * python
+ * anaconda
+
+ ### Download
+
+ #### Model
+ Download the model weights from [here](https://www.dropbox.com/sh/hrxcyug1bmbj6cs/AAAxq_zI5eawcLjM8zvUwaXha?dl=0).
+
+ #### Datasets
+ * The datasets used for training and evaluating our model can be found [here](https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/).
+
+ * The Artistic-Faces dataset can be found [here](http://www.faculty.idc.ac.il/arik/site/foa/artistic-faces-dataset.asp).
+
+ * Training images with texture augmentation can be found [here](https://www.dropbox.com/sh/av2k1i1082z0nie/AAC5qV1E2UkqpDLVsv7TazMta?dl=0).
+ Before applying texture style transfer, the training images were cropped to the ground-truth face bounding box with a 25% margin. To crop training images, run the script `crop_training_set.py`.
+
+ * Our model expects the following directory structure of landmark detection datasets:
+ ```
+ landmark_detection_datasets
+     ├── training
+     ├── test
+     ├── challenging
+     ├── common
+     ├── full
+     ├── crop_gt_margin_0.25 (cropped images of training set)
+     └── crop_gt_margin_0.25_ns (cropped images of training set + texture style transfer)
+ ```
+ ### Install
+
+ Create a virtual environment and install the following:
+ * opencv
+ * menpo
+ * menpofit
+ * tensorflow-gpu
+
+ For Python 2:
+ ```
+ conda create -n foa_env python=2.7 anaconda
+ source activate foa_env
+ conda install -c menpo opencv
+ conda install -c menpo menpo
+ conda install -c menpo menpofit
+ pip install tensorflow-gpu
+
+ ```
+
+ For Python 3:
+ ```
+ conda create -n foa_env python=3.5 anaconda
+ source activate foa_env
+ conda install -c menpo opencv
+ conda install -c menpo menpo
+ conda install -c menpo menpofit
+ pip3 install tensorflow-gpu
+
+ ```
+
+ Clone the repository:
+
+ ```
+ git clone https://github.com/papulke/deep_face_heatmaps
+ ```
+
+ ## Instructions
+
+ ### Training
+
+ To train the network, run `train_heatmaps_network.py`.
+
+ Example of training a model with texture augmentation (100% of images) and geometric augmentation (~70% of images):
+ ```
+ python train_heatmaps_network.py --output_dir='test_artistic_aug' --augment_geom=True \
+ --augment_texture=True --p_texture=1. --p_geom=0.7
+ ```
+
+ ### Testing
+
+ To use the detection framework to predict landmarks, run the script `predict_landmarks.py`.
+
+ ## Acknowledgments
+
+ * [ect](https://github.com/HongwenZhang/ECT-FaceAlignment)
+ * [menpo](https://github.com/menpo/menpo)
+ * [menpofit](https://github.com/menpo/menpofit)
+ * [mdm](https://github.com/trigeorgis/mdm)
+ * [style transfer implementation](https://github.com/woodrush/neural-art-tf)
+ * [painter-by-numbers dataset](https://www.kaggle.com/c/painter-by-numbers/data)
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.py ADDED
File without changes
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__init__.pyc ADDED
Binary file (161 Bytes). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (157 Bytes). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/data_loading_functions.cpython-36.pyc ADDED
Binary file (4.56 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deep_heatmaps_model_fusion_net.cpython-36.pyc ADDED
Binary file (21.6 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/deformation_functions.cpython-36.pyc ADDED
Binary file (9 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/logging_functions.cpython-36.pyc ADDED
Binary file (5.81 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/menpo_functions.cpython-36.pyc ADDED
Binary file (9.22 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/ops.cpython-36.pyc ADDED
Binary file (3.6 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/__pycache__/pdm_clm_functions.cpython-36.pyc ADDED
Binary file (6.34 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/crop_training_set.py ADDED
@@ -0,0 +1,38 @@
+ from scipy.misc import imsave
+ from menpo_functions import *
+ from data_loading_functions import *
+
+
+ # define paths & parameters for cropping dataset
+ img_dir = '~/landmark_detection_datasets/'
+ dataset = 'training'
+ bb_type = 'gt'
+ margin = 0.25
+ image_size = 256
+
+ # load bounding boxes
+ bb_dir = os.path.join(img_dir, 'Bounding_Boxes')
+ bb_dictionary = load_bb_dictionary(bb_dir, mode='TRAIN', test_data=dataset)
+
+ # directory for saving face crops
+ outdir = os.path.join(img_dir, 'crop_'+bb_type+'_margin_'+str(margin))
+ if not os.path.exists(outdir):
+     os.mkdir(outdir)
+
+ # load images
+ imgs_to_crop = load_menpo_image_list(
+     img_dir=img_dir, train_crop_dir=None, img_dir_ns=None, mode='TRAIN', bb_dictionary=bb_dictionary,
+     image_size=image_size, margin=margin, bb_type=bb_type, augment_basic=False)
+
+ # save cropped images with matching landmarks
+ print ("\ncropping dataset from: "+os.path.join(img_dir, dataset))
+ print ("\nsaving cropped dataset to: "+outdir)
+ for im in imgs_to_crop:
+     if im.pixels.shape[0] == 1:
+         im_pixels = gray2rgb(np.squeeze(im.pixels))
+     else:
+         im_pixels = np.rollaxis(im.pixels, 0, 3)
+     imsave(os.path.join(outdir, im.path.name.split('.')[0]+'.png'), im_pixels)
+     mio.export_landmark_file(im.landmarks['PTS'], os.path.join(outdir, im.path.name.split('.')[0]+'.pts'))
+
+ print ("\ncropping dataset completed!")
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.py ADDED
@@ -0,0 +1,161 @@
1
+ import numpy as np
2
+ import os
3
+ from skimage.color import gray2rgb
4
+
5
+
6
+ def train_val_shuffle_inds_per_epoch(valid_inds, train_inds, train_iter, batch_size, log_path, save_log=True):
7
+ """shuffle image indices for each training epoch and save to log"""
8
+
9
+ np.random.seed(0)
10
+ num_train_images = len(train_inds)
11
+ num_epochs = int(np.ceil((1. * train_iter) / (1. * num_train_images / batch_size)))+1
12
+ epoch_inds_shuffle = np.zeros((num_epochs, num_train_images)).astype(int)
13
+ img_inds = np.arange(num_train_images)
14
+ for i in range(num_epochs):
15
+ np.random.shuffle(img_inds)
16
+ epoch_inds_shuffle[i, :] = img_inds
17
+
18
+ if save_log:
19
+ with open(os.path.join(log_path, "train_val_shuffle_inds.csv"), "wb") as f:
20
+ if valid_inds is not None:
21
+ f.write(b'valid inds\n')
22
+ np.savetxt(f, valid_inds.reshape(1, -1), fmt='%i', delimiter=",")
23
+ f.write(b'train inds\n')
24
+ np.savetxt(f, train_inds.reshape(1, -1), fmt='%i', delimiter=",")
25
+ f.write(b'shuffle inds\n')
26
+ np.savetxt(f, epoch_inds_shuffle, fmt='%i', delimiter=",")
27
+
28
+ return epoch_inds_shuffle
29
+
30
+
31
+ def gaussian(x, y, x0, y0, sigma=6):
32
+ return 1./(np.sqrt(2*np.pi)*sigma) * np.exp(-0.5 * ((x-x0)**2 + (y-y0)**2) / sigma**2)
33
+
34
+
35
+ def create_gaussian_filter(sigma=6, win_mult=3.5):
36
+ win_size = int(win_mult * sigma)
37
+ x, y = np.mgrid[0:2*win_size+1, 0:2*win_size+1]
38
+ gauss_filt = (8./3)*sigma*gaussian(x, y, win_size, win_size, sigma=sigma) # same as in ECT
39
+ return gauss_filt
40
+
41
+
42
+ def load_images(img_list, batch_inds, image_size=256, c_dim=3, scale=255):
43
+
44
+ """ load images as a numpy array from menpo image list """
45
+
46
+ num_inputs = len(batch_inds)
47
+ batch_menpo_images = img_list[batch_inds]
48
+
49
+ images = np.zeros([num_inputs, image_size, image_size, c_dim]).astype('float32')
50
+
51
+ for ind, img in enumerate(batch_menpo_images):
52
+ if img.n_channels < 3 and c_dim == 3:
53
+ images[ind, :, :, :] = gray2rgb(img.pixels_with_channels_at_back())
54
+ else:
55
+ images[ind, :, :, :] = img.pixels_with_channels_at_back()
56
+
57
+ if scale == 255:
58
+ images *= 255
59
+ elif scale == 0:
60
+ images = 2 * images - 1
61
+
62
+ return images
63
+
64
+
65
+ # loading functions with pre-allocation and approx heat-map generation
66
+
67
+
68
+ def create_approx_heat_maps_alloc_once(landmarks, maps, gauss_filt=None, win_mult=3.5, num_landmarks=68, image_size=256,
69
+ sigma=6):
70
+ """ create heatmaps from input landmarks"""
71
+ maps.fill(0.)
72
+
73
+ win_size = int(win_mult * sigma)
74
+ filt_size = 2 * win_size + 1
75
+ landmarks = landmarks.astype(int)
76
+
77
+ if gauss_filt is None:
78
+ x_small, y_small = np.mgrid[0:2 * win_size + 1, 0:2 * win_size + 1]
79
+ gauss_filt = (8. / 3) * sigma * gaussian(x_small, y_small, win_size, win_size, sigma=sigma) # same as in ECT
80
+
81
+ for i in range(num_landmarks):
82
+
83
+ min_row = landmarks[i, 0] - win_size
84
+ max_row = landmarks[i, 0] + win_size + 1
85
+ min_col = landmarks[i, 1] - win_size
86
+ max_col = landmarks[i, 1] + win_size + 1
87
+
88
+ if min_row < 0:
89
+ min_row_gap = -1 * min_row
90
+ min_row = 0
91
+ else:
92
+ min_row_gap = 0
93
+
94
+ if min_col < 0:
95
+ min_col_gap = -1 * min_col
96
+ min_col = 0
97
+ else:
98
+ min_col_gap = 0
99
+
100
+ if max_row > image_size:
101
+ max_row_gap = max_row - image_size
102
+ max_row = image_size
103
+ else:
104
+ max_row_gap = 0
105
+
106
+ if max_col > image_size:
107
+ max_col_gap = max_col - image_size
108
+ max_col = image_size
109
+ else:
110
+ max_col_gap = 0
111
+
112
+ maps[min_row:max_row, min_col:max_col, i] =\
113
+ gauss_filt[min_row_gap:filt_size - 1 * max_row_gap, min_col_gap:filt_size - 1 * max_col_gap]
114
+
115
+
116
+ def load_images_landmarks_approx_maps_alloc_once(
117
+ img_list, batch_inds, images, maps_small, maps, landmarks, image_size=256, num_landmarks=68,
118
+ scale=255, gauss_filt_large=None, gauss_filt_small=None, win_mult=3.5, sigma=6, save_landmarks=False):
119
+
120
+ """ load images and gt landmarks from menpo image list, and create matching heatmaps """
121
+
122
+ batch_menpo_images = img_list[batch_inds]
123
+ c_dim = images.shape[-1]
124
+ grp_name = batch_menpo_images[0].landmarks.group_labels[0]
125
+
126
+ win_size_large = int(win_mult * sigma)
127
+ win_size_small = int(win_mult * (1.*sigma/4))
128
+
129
+ if gauss_filt_small is None:
130
+ x_small, y_small = np.mgrid[0:2 * win_size_small + 1, 0:2 * win_size_small + 1]
131
+ gauss_filt_small = (8. / 3) * (1.*sigma/4) * gaussian(
132
+ x_small, y_small, win_size_small, win_size_small, sigma=1.*sigma/4) # same as in ECT
133
+ if gauss_filt_large is None:
134
+ x_large, y_large = np.mgrid[0:2 * win_size_large + 1, 0:2 * win_size_large + 1]
135
+ gauss_filt_large = (8. / 3) * sigma * gaussian(x_large, y_large, win_size_large, win_size_large, sigma=sigma) # same as in ECT
136
+
137
+ for ind, img in enumerate(batch_menpo_images):
138
+ if img.n_channels < 3 and c_dim == 3:
139
+ images[ind, :, :, :] = gray2rgb(img.pixels_with_channels_at_back())
140
+ else:
141
+ images[ind, :, :, :] = img.pixels_with_channels_at_back()
142
+
143
+ lms = img.landmarks[grp_name].points
144
+ lms = np.minimum(lms, image_size - 1)
145
+ create_approx_heat_maps_alloc_once(
146
+ landmarks=lms, maps=maps[ind, :, :, :], gauss_filt=gauss_filt_large, win_mult=win_mult,
147
+ num_landmarks=num_landmarks, image_size=image_size, sigma=sigma)
148
+
149
+ lms_small = img.resize([image_size / 4, image_size / 4]).landmarks[grp_name].points
150
+ lms_small = np.minimum(lms_small, image_size / 4 - 1)
151
+ create_approx_heat_maps_alloc_once(
152
+ landmarks=lms_small, maps=maps_small[ind, :, :, :], gauss_filt=gauss_filt_small, win_mult=win_mult,
153
+ num_landmarks=num_landmarks, image_size=image_size / 4, sigma=1. * sigma / 4)
154
+
155
+ if save_landmarks:
156
+ landmarks[ind, :, :] = lms
157
+
158
+ if scale == 255:
159
+ images *= 255
160
+ elif scale == 0:
161
+ images = 2 * images - 1
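A small usage sketch (not part of the repository) for the heatmap helpers defined above: build the shared Gaussian filter once with create_gaussian_filter, then stamp it around each landmark into a pre-allocated buffer with create_approx_heat_maps_alloc_once; the landmark coordinates here are random stand-ins.

import numpy as np

num_landmarks, image_size, sigma = 68, 256, 6
gauss_filt = create_gaussian_filter(sigma=sigma, win_mult=3.5)
maps = np.zeros((image_size, image_size, num_landmarks), dtype=np.float32)

fake_landmarks = np.random.randint(0, image_size, size=(num_landmarks, 2))
create_approx_heat_maps_alloc_once(landmarks=fake_landmarks, maps=maps,
                                   gauss_filt=gauss_filt, win_mult=3.5,
                                   num_landmarks=num_landmarks,
                                   image_size=image_size, sigma=sigma)
print(maps.shape, maps.max())   # the pre-allocated buffer is filled in place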
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/data_loading_functions.pyc ADDED
Binary file (5.95 kB). View file
 
marlenezw/audio-driven-animations/MakeItTalk/face_of_art/deep_heatmaps_model_fusion_net.py ADDED
@@ -0,0 +1,872 @@
1
+ import scipy.io
2
+ import scipy.misc
3
+ from glob import glob
4
+ import os
5
+ import numpy as np
6
+ from thirdparty.face_of_art.ops import *
7
+ import tensorflow as tf
8
+ from tensorflow import contrib
9
+ from thirdparty.face_of_art.menpo_functions import *
10
+ from thirdparty.face_of_art.logging_functions import *
11
+ from thirdparty.face_of_art.data_loading_functions import *
12
+
13
+
14
+ class DeepHeatmapsModel(object):
15
+
16
+ """facial landmark localization Network"""
17
+
18
+ def __init__(self, mode='TRAIN', train_iter=100000, batch_size=10, learning_rate=1e-3, l_weight_primary=1.,
19
+ l_weight_fusion=1.,l_weight_upsample=3.,adam_optimizer=True,momentum=0.95,step=100000, gamma=0.1,reg=0,
20
+ weight_initializer='xavier', weight_initializer_std=0.01, bias_initializer=0.0, image_size=256,c_dim=3,
21
+ num_landmarks=68, sigma=1.5, scale=1, margin=0.25, bb_type='gt', win_mult=3.33335,
22
+ augment_basic=True,augment_texture=False, p_texture=0., augment_geom=False, p_geom=0.,
23
+ output_dir='output', save_model_path='model',
24
+ save_sample_path='sample', save_log_path='logs', test_model_path='model/deep_heatmaps-50000',
25
+ pre_train_path='model/deep_heatmaps-50000', load_pretrain=False, load_primary_only=False,
26
+ img_path='data', test_data='full', valid_data='full', valid_size=0, log_valid_every=5,
27
+ train_crop_dir='crop_gt_margin_0.25', img_dir_ns='crop_gt_margin_0.25_ns',
28
+ print_every=100, save_every=5000, sample_every=5000, sample_grid=9, sample_to_log=True,
29
+ debug_data_size=20, debug=False, epoch_data_dir='epoch_data', use_epoch_data=False, menpo_verbose=True):
30
+
31
+ # define some extra parameters
32
+
33
+ self.log_histograms = False # save weight + gradient histogram to log
34
+ self.save_valid_images = True # sample heat maps of validation images
35
+ self.sample_per_channel = False # sample heatmaps separately for each landmark
36
+
37
+ # for fine-tuning, choose reset_training_op==True. when resuming training, reset_training_op==False
38
+ self.reset_training_op = False
39
+
40
+ self.fast_img_gen = True
41
+
42
+ self.compute_nme = True # compute normalized mean error
43
+
44
+ self.config = tf.ConfigProto()
45
+ self.config.gpu_options.allow_growth = True
46
+
47
+ # sampling and logging parameters
48
+ self.print_every = print_every # print losses to screen + log
49
+ self.save_every = save_every # save model
50
+ self.sample_every = sample_every # save images of gen heat maps compared to GT
51
+ self.sample_grid = sample_grid # number of training images in sample
52
+ self.sample_to_log = sample_to_log # sample images to log instead of disk
53
+ self.log_valid_every = log_valid_every # log validation loss (in epochs)
54
+
55
+ self.debug = debug
56
+ self.debug_data_size = debug_data_size
57
+ self.use_epoch_data = use_epoch_data
58
+ self.epoch_data_dir = epoch_data_dir
59
+
60
+ self.load_pretrain = load_pretrain
61
+ self.load_primary_only = load_primary_only
62
+ self.pre_train_path = pre_train_path
63
+
64
+ self.mode = mode
65
+ self.train_iter = train_iter
66
+ self.learning_rate = learning_rate
67
+
68
+ self.image_size = image_size
69
+ self.c_dim = c_dim
70
+ self.batch_size = batch_size
71
+
72
+ self.num_landmarks = num_landmarks
73
+
74
+ self.save_log_path = save_log_path
75
+ self.save_sample_path = save_sample_path
76
+ self.save_model_path = save_model_path
77
+ self.test_model_path = test_model_path
78
+ self.img_path=img_path
79
+
80
+ self.momentum = momentum
81
+ self.step = step # for lr decay
82
+ self.gamma = gamma # for lr decay
83
+ self.reg = reg # weight decay scale
84
+ self.l_weight_primary = l_weight_primary # primary loss weight
85
+ self.l_weight_fusion = l_weight_fusion # fusion loss weight
86
+ self.l_weight_upsample = l_weight_upsample # upsample loss weight
87
+
88
+ self.weight_initializer = weight_initializer # random_normal or xavier
89
+ self.weight_initializer_std = weight_initializer_std
90
+ self.bias_initializer = bias_initializer
91
+ self.adam_optimizer = adam_optimizer
92
+
93
+ self.sigma = sigma # sigma for heatmap generation
94
+ self.scale = scale # scale for image normalization 255 / 1 / 0
95
+ self.win_mult = win_mult # gaussian filter size for cpu/gpu approximation: 2 * sigma * win_mult + 1
96
+
97
+ self.test_data = test_data # if mode is TEST, this chooses the set to use: full/common/challenging/test/art
98
+ self.train_crop_dir = train_crop_dir
99
+ self.img_dir_ns = os.path.join(img_path,img_dir_ns)
100
+ self.augment_basic = augment_basic # perform basic augmentation (rotation,flip,crop)
101
+ self.augment_texture = augment_texture # perform artistic texture augmentation (NS)
102
+ self.p_texture = p_texture # initial probability of artistic texture augmentation
103
+ self.augment_geom = augment_geom # perform artistic geometric augmentation
104
+ self.p_geom = p_geom # initial probability of artistic geometric augmentation
105
+
106
+ self.valid_size = valid_size
107
+ self.valid_data = valid_data
108
+
109
+ # load image, bb and landmark data using menpo
110
+ self.bb_dir = os.path.join(img_path, 'Bounding_Boxes')
111
+ self.bb_dictionary = load_bb_dictionary(self.bb_dir, mode, test_data=self.test_data)
112
+
113
+ # use pre-augmented data, to save time during training
114
+ if self.use_epoch_data:
115
+ epoch_0 = os.path.join(self.epoch_data_dir, '0')
116
+ self.img_menpo_list = load_menpo_image_list(
117
+ img_path, train_crop_dir=epoch_0, img_dir_ns=None, mode=mode, bb_dictionary=self.bb_dictionary,
118
+ image_size=self.image_size, test_data=self.test_data, augment_basic=False, augment_texture=False,
119
+ augment_geom=False, verbose=menpo_verbose)
120
+ else:
121
+ self.img_menpo_list = load_menpo_image_list(
122
+ img_path, train_crop_dir, self.img_dir_ns, mode, bb_dictionary=self.bb_dictionary,
123
+ image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.test_data,
124
+ augment_basic=augment_basic, augment_texture=augment_texture, p_texture=p_texture,
125
+ augment_geom=augment_geom, p_geom=p_geom, verbose=menpo_verbose)
126
+
127
+ if mode == 'TRAIN':
128
+
129
+ train_params = locals()
130
+ print_training_params_to_file(train_params) # save init parameters
131
+
132
+ self.train_inds = np.arange(len(self.img_menpo_list))
133
+
134
+ if self.debug:
135
+ self.train_inds = self.train_inds[:self.debug_data_size]
136
+ self.img_menpo_list = self.img_menpo_list[self.train_inds]
137
+
138
+ if valid_size > 0:
139
+
140
+ self.valid_bb_dictionary = load_bb_dictionary(self.bb_dir, 'TEST', test_data=self.valid_data)
141
+ self.valid_img_menpo_list = load_menpo_image_list(
142
+ img_path, train_crop_dir, self.img_dir_ns, 'TEST', bb_dictionary=self.valid_bb_dictionary,
143
+ image_size=self.image_size, margin=margin, bb_type=bb_type, test_data=self.valid_data,
144
+ verbose=menpo_verbose)
145
+
146
+ np.random.seed(0)
147
+ self.val_inds = np.arange(len(self.valid_img_menpo_list))
148
+ np.random.shuffle(self.val_inds)
149
+ self.val_inds = self.val_inds[:self.valid_size]
150
+
151
+ self.valid_img_menpo_list = self.valid_img_menpo_list[self.val_inds]
152
+
153
+ self.valid_images_loaded =\
154
+ np.zeros([self.valid_size, self.image_size, self.image_size, self.c_dim]).astype('float32')
155
+ self.valid_gt_maps_small_loaded =\
156
+ np.zeros([self.valid_size, self.image_size / 4, self.image_size / 4,
157
+ self.num_landmarks]).astype('float32')
158
+ self.valid_gt_maps_loaded =\
159
+ np.zeros([self.valid_size, self.image_size, self.image_size, self.num_landmarks]
160
+ ).astype('float32')
161
+ self.valid_landmarks_loaded = np.zeros([self.valid_size, num_landmarks, 2]).astype('float32')
162
+ self.valid_landmarks_pred = np.zeros([self.valid_size, self.num_landmarks, 2]).astype('float32')
163
+
164
+ load_images_landmarks_approx_maps_alloc_once(
165
+ self.valid_img_menpo_list, np.arange(self.valid_size), images=self.valid_images_loaded,
166
+ maps_small=self.valid_gt_maps_small_loaded, maps=self.valid_gt_maps_loaded,
167
+ landmarks=self.valid_landmarks_loaded, image_size=self.image_size,
168
+ num_landmarks=self.num_landmarks, scale=self.scale, win_mult=self.win_mult, sigma=self.sigma,
169
+ save_landmarks=self.compute_nme)
170
+
171
+ if self.valid_size > self.sample_grid:
172
+ self.valid_gt_maps_loaded = self.valid_gt_maps_loaded[:self.sample_grid]
173
+ self.valid_gt_maps_small_loaded = self.valid_gt_maps_small_loaded[:self.sample_grid]
174
+ else:
175
+ self.val_inds = None
176
+
177
+ self.epoch_inds_shuffle = train_val_shuffle_inds_per_epoch(
178
+ self.val_inds, self.train_inds, train_iter, batch_size, save_log_path)
179
+
180
+ def add_placeholders(self):
181
+
182
+ if self.mode == 'TEST':
183
+ self.images = tf.placeholder(
184
+ tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'images')
185
+
186
+ self.heatmaps = tf.placeholder(
187
+ tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'heatmaps')
188
+
189
+ self.heatmaps_small = tf.placeholder(
190
+ tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'heatmaps_small')
191
+ self.lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'lms')
192
+ self.pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'pred_lms')
193
+
194
+ elif self.mode == 'TRAIN':
195
+ self.images = tf.placeholder(
196
+ tf.float32, [None, self.image_size, self.image_size, self.c_dim], 'train_images')
197
+
198
+ self.heatmaps = tf.placeholder(
199
+ tf.float32, [None, self.image_size, self.image_size, self.num_landmarks], 'train_heatmaps')
200
+
201
+ self.heatmaps_small = tf.placeholder(
202
+ tf.float32, [None, int(self.image_size/4), int(self.image_size/4), self.num_landmarks], 'train_heatmaps_small')
203
+
204
+ self.train_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_lms')
205
+ self.train_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'train_pred_lms')
206
+
207
+ self.valid_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_lms')
208
+ self.valid_pred_lms = tf.placeholder(tf.float32, [None, self.num_landmarks, 2], 'valid_pred_lms')
209
+
210
+ # self.p_texture_log = tf.placeholder(tf.float32, [])
211
+ # self.p_geom_log = tf.placeholder(tf.float32, [])
212
+
213
+ # self.sparse_hm_small = tf.placeholder(tf.float32, [None, int(self.image_size/4), int(self.image_size/4), 1])
214
+ # self.sparse_hm = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, 1])
215
+
216
+ if self.sample_to_log:
217
+ row = int(np.sqrt(self.sample_grid))
218
+ self.log_image_map_small = tf.placeholder(
219
+ tf.uint8, [None, row * int(self.image_size/4), 3 * row * int(self.image_size/4), self.c_dim],
220
+ 'sample_img_map_small')
221
+ self.log_image_map = tf.placeholder(
222
+ tf.uint8, [None, row * self.image_size, 3 * row * self.image_size, self.c_dim],
223
+ 'sample_img_map')
224
+ if self.sample_per_channel:
225
+ row = np.ceil(np.sqrt(self.num_landmarks)).astype(np.int64)
226
+ self.log_map_channels_small = tf.placeholder(
227
+ tf.uint8, [None, row * int(self.image_size/4), 2 * row * int(self.image_size/4), self.c_dim],
228
+ 'sample_map_channels_small')
229
+ self.log_map_channels = tf.placeholder(
230
+ tf.uint8, [None, row * self.image_size, 2 * row * self.image_size, self.c_dim],
231
+ 'sample_map_channels')
232
+
233
+ def heatmaps_network(self, input_images, reuse=None, name='pred_heatmaps'):
234
+
235
+ with tf.name_scope(name):
236
+
237
+ if self.weight_initializer == 'xavier':
238
+ weight_initializer = contrib.layers.xavier_initializer()
239
+ else:
240
+ weight_initializer = tf.random_normal_initializer(stddev=self.weight_initializer_std)
241
+
242
+ bias_init = tf.constant_initializer(self.bias_initializer)
243
+
244
+ with tf.variable_scope('heatmaps_network'):
245
+ with tf.name_scope('primary_net'):
246
+
247
+ l1 = conv_relu_pool(input_images, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
248
+ reuse=reuse, var_scope='conv_1')
249
+ l2 = conv_relu_pool(l1, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
250
+ reuse=reuse, var_scope='conv_2')
251
+ l3 = conv_relu(l2, 5, 128, conv_ker_init=weight_initializer, conv_bias_init=bias_init,
252
+ reuse=reuse, var_scope='conv_3')
253
+
254
+ l4_1 = conv_relu(l3, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
255
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_1')
256
+ l4_2 = conv_relu(l3, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
257
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_2')
258
+ l4_3 = conv_relu(l3, 3, 128, conv_dilation=3, conv_ker_init=weight_initializer,
259
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_3')
260
+ l4_4 = conv_relu(l3, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
261
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_4_4')
262
+
263
+ l4 = tf.concat([l4_1, l4_2, l4_3, l4_4], 3, name='conv_4')
264
+
265
+ l5_1 = conv_relu(l4, 3, 256, conv_dilation=1, conv_ker_init=weight_initializer,
266
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_1')
267
+ l5_2 = conv_relu(l4, 3, 256, conv_dilation=2, conv_ker_init=weight_initializer,
268
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_2')
269
+ l5_3 = conv_relu(l4, 3, 256, conv_dilation=3, conv_ker_init=weight_initializer,
270
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_3')
271
+ l5_4 = conv_relu(l4, 3, 256, conv_dilation=4, conv_ker_init=weight_initializer,
272
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_5_4')
273
+
274
+ l5 = tf.concat([l5_1, l5_2, l5_3, l5_4], 3, name='conv_5')
275
+
276
+ l6 = conv_relu(l5, 1, 512, conv_ker_init=weight_initializer,
277
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_6')
278
+ l7 = conv_relu(l6, 1, 256, conv_ker_init=weight_initializer,
279
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_7')
280
+ primary_out = conv(l7, 1, self.num_landmarks, conv_ker_init=weight_initializer,
281
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_8')
282
+
283
+ with tf.name_scope('fusion_net'):
284
+
285
+ l_fsn_0 = tf.concat([l3, l7], 3, name='conv_3_7_fsn')
286
+
287
+ l_fsn_1_1 = conv_relu(l_fsn_0, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
288
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_1')
289
+ l_fsn_1_2 = conv_relu(l_fsn_0, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
290
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_2')
291
+ l_fsn_1_3 = conv_relu(l_fsn_0, 3, 64, conv_dilation=3, conv_ker_init=weight_initializer,
292
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_1_3')
293
+
294
+ l_fsn_1 = tf.concat([l_fsn_1_1, l_fsn_1_2, l_fsn_1_3], 3, name='conv_fsn_1')
295
+
296
+ l_fsn_2_1 = conv_relu(l_fsn_1, 3, 64, conv_dilation=1, conv_ker_init=weight_initializer,
297
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_1')
298
+ l_fsn_2_2 = conv_relu(l_fsn_1, 3, 64, conv_dilation=2, conv_ker_init=weight_initializer,
299
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_2')
300
+ l_fsn_2_3 = conv_relu(l_fsn_1, 3, 64, conv_dilation=4, conv_ker_init=weight_initializer,
301
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_3')
302
+ l_fsn_2_4 = conv_relu(l_fsn_1, 5, 64, conv_dilation=3, conv_ker_init=weight_initializer,
303
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_2_4')
304
+
305
+ l_fsn_2 = tf.concat([l_fsn_2_1, l_fsn_2_2, l_fsn_2_3, l_fsn_2_4], 3, name='conv_fsn_2')
306
+
307
+ l_fsn_3_1 = conv_relu(l_fsn_2, 3, 128, conv_dilation=1, conv_ker_init=weight_initializer,
308
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_1')
309
+ l_fsn_3_2 = conv_relu(l_fsn_2, 3, 128, conv_dilation=2, conv_ker_init=weight_initializer,
310
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_2')
311
+ l_fsn_3_3 = conv_relu(l_fsn_2, 3, 128, conv_dilation=4, conv_ker_init=weight_initializer,
312
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_3')
313
+ l_fsn_3_4 = conv_relu(l_fsn_2, 5, 128, conv_dilation=3, conv_ker_init=weight_initializer,
314
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_3_4')
315
+
316
+ l_fsn_3 = tf.concat([l_fsn_3_1, l_fsn_3_2, l_fsn_3_3, l_fsn_3_4], 3, name='conv_fsn_3')
317
+
318
+ l_fsn_4 = conv_relu(l_fsn_3, 1, 256, conv_ker_init=weight_initializer,
319
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_4')
320
+ fusion_out = conv(l_fsn_4, 1, self.num_landmarks, conv_ker_init=weight_initializer,
321
+ conv_bias_init=bias_init, reuse=reuse, var_scope='conv_fsn_5')
322
+
323
+ with tf.name_scope('upsample_net'):
324
+
325
+ out = deconv(fusion_out, 8, self.num_landmarks, conv_stride=4,
326
+ conv_ker_init=deconv2d_bilinear_upsampling_initializer(
327
+ [8, 8, self.num_landmarks, self.num_landmarks]), conv_bias_init=bias_init,
328
+ reuse=reuse, var_scope='deconv_1')
329
+
330
+ self.all_layers = [l1, l2, l3, l4, l5, l6, l7, primary_out, l_fsn_1, l_fsn_2, l_fsn_3, l_fsn_4,
331
+ fusion_out, out]
332
+
333
+ return primary_out, fusion_out, out
334
+
335
+ def build_model(self):
336
+ self.pred_hm_p, self.pred_hm_f, self.pred_hm_u = self.heatmaps_network(self.images,name='heatmaps_prediction')
337
+
338
+ def create_loss_ops(self):
339
+
340
+ def nme_norm_eyes(pred_landmarks, real_landmarks, normalize=True, name='NME'):
341
+ """calculate normalized mean error on landmarks - normalize with inter pupil distance"""
342
+
343
+ with tf.name_scope(name):
344
+ with tf.name_scope('real_pred_landmarks_rmse'):
345
+ # calculate RMS ERROR between GT and predicted lms
346
+ landmarks_rms_err = tf.reduce_mean(
347
+ tf.sqrt(tf.reduce_sum(tf.square(pred_landmarks - real_landmarks), axis=2)), axis=1)
348
+ if normalize:
349
+ # normalize RMS ERROR with inter-pupil distance of GT lms
350
+ with tf.name_scope('inter_pupil_dist'):
351
+ with tf.name_scope('left_eye_center'):
352
+ p1 = tf.reduce_mean(tf.slice(real_landmarks, [0, 42, 0], [-1, 6, 2]), axis=1)
353
+ with tf.name_scope('right_eye_center'):
354
+ p2 = tf.reduce_mean(tf.slice(real_landmarks, [0, 36, 0], [-1, 6, 2]), axis=1)
355
+
356
+ eye_dist = tf.sqrt(tf.reduce_sum(tf.square(p1 - p2), axis=1))
357
+
358
+ return landmarks_rms_err / eye_dist
359
+ else:
360
+ return landmarks_rms_err
361
+
362
+ if self.mode == 'TRAIN':
363
+
364
+ # calculate L2 loss between ideal and predicted heatmaps
365
+ primary_maps_diff = self.pred_hm_p - self.heatmaps_small
366
+ fusion_maps_diff = self.pred_hm_f - self.heatmaps_small
367
+ upsample_maps_diff = self.pred_hm_u - self.heatmaps
368
+
369
+ self.l2_primary = tf.reduce_mean(tf.square(primary_maps_diff))
370
+ self.l2_fusion = tf.reduce_mean(tf.square(fusion_maps_diff))
371
+ self.l2_upsample = tf.reduce_mean(tf.square(upsample_maps_diff))
372
+
373
+ self.total_loss = 1000.*(self.l_weight_primary * self.l2_primary + self.l_weight_fusion * self.l2_fusion +
374
+ self.l_weight_upsample * self.l2_upsample)
375
+
376
+ # add weight decay
377
+ self.total_loss += self.reg * tf.add_n(
378
+ [tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name])
379
+
380
+ # compute normalized mean error on gt vs. predicted landmarks (for validation)
381
+ if self.compute_nme:
382
+ self.nme_loss = tf.reduce_mean(nme_norm_eyes(self.train_pred_lms, self.train_lms))
383
+
384
+ if self.valid_size > 0 and self.compute_nme:
385
+ self.valid_nme_loss = tf.reduce_mean(nme_norm_eyes(self.valid_pred_lms, self.valid_lms))
386
+
387
+ elif self.mode == 'TEST' and self.compute_nme:
388
+ self.nme_per_image = nme_norm_eyes(self.pred_lms, self.lms)
389
+ self.nme_loss = tf.reduce_mean(self.nme_per_image)
390
+
391
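For reference, the error computed by `nme_norm_eyes` can be reproduced for a single example with plain NumPy. This is only an illustrative sketch that mirrors the TensorFlow ops above, assuming the 68-point iBUG markup (right eye at points 36-41, left eye at 42-47):

```python
import numpy as np

def nme_inter_pupil(pred, gt):
    """pred, gt: (68, 2) landmark arrays for one image."""
    rms_err = np.mean(np.sqrt(np.sum((pred - gt) ** 2, axis=1)))  # mean point-to-point distance
    left_eye = gt[42:48].mean(axis=0)                             # left eye center (points 42-47)
    right_eye = gt[36:42].mean(axis=0)                            # right eye center (points 36-41)
    inter_pupil = np.sqrt(np.sum((left_eye - right_eye) ** 2))
    return rms_err / inter_pupil

gt = np.random.rand(68, 2) * 256
pred = gt + np.random.randn(68, 2)
print(nme_inter_pupil(pred, gt))
```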
+    def predict_valid_landmarks_in_batches(self, images, session):
+
+        num_images = int(images.shape[0])
+        num_batches = int(1. * num_images / self.batch_size)
+        if num_batches == 0:
+            batch_size = num_images
+            num_batches = 1
+        else:
+            batch_size = self.batch_size
+
+        for j in range(num_batches):
+
+            batch_images = images[j * batch_size:(j + 1) * batch_size, :, :, :]
+            batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images})
+            batch_heat_maps_to_landmarks_alloc_once(
+                batch_maps=batch_maps_pred,
+                batch_landmarks=self.valid_landmarks_pred[j * batch_size:(j + 1) * batch_size, :, :],
+                batch_size=batch_size, image_size=self.image_size, num_landmarks=self.num_landmarks)
+
+        remainder = num_images - num_batches * batch_size
+        if remainder > 0:
+            batch_images = images[-remainder:, :, :, :]
+            batch_maps_pred = session.run(self.pred_hm_u, {self.images: batch_images})
+
+            batch_heat_maps_to_landmarks_alloc_once(
+                batch_maps=batch_maps_pred,
+                batch_landmarks=self.valid_landmarks_pred[-remainder:, :, :],
+                batch_size=remainder, image_size=self.image_size, num_landmarks=self.num_landmarks)
+
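`batch_heat_maps_to_landmarks_alloc_once` is defined elsewhere in the repository and writes predicted landmark coordinates into a pre-allocated array. The usual way to read a landmark out of a heatmap channel is the per-channel argmax; a minimal, assumed sketch of that idea (coordinate order and any sub-pixel refinement in the actual helper may differ):

```python
import numpy as np

def heatmaps_to_landmarks_argmax(maps):
    """maps: (H, W, num_landmarks) -> (num_landmarks, 2) array of (x, y) peak locations."""
    h, w, n = maps.shape
    flat_idx = maps.reshape(h * w, n).argmax(axis=0)   # peak index per channel
    ys, xs = np.unravel_index(flat_idx, (h, w))
    return np.stack([xs, ys], axis=1).astype(np.float32)

maps = np.random.rand(64, 64, 68)
print(heatmaps_to_landmarks_argmax(maps).shape)  # (68, 2)
```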
+    def create_summary_ops(self):
+        """create summary ops for logging"""
+
+        # loss summary
+        l2_primary = tf.summary.scalar('l2_primary', self.l2_primary)
+        l2_fusion = tf.summary.scalar('l2_fusion', self.l2_fusion)
+        l2_upsample = tf.summary.scalar('l2_upsample', self.l2_upsample)
+
+        l_total = tf.summary.scalar('l_total', self.total_loss)
+        self.batch_summary_op = tf.summary.merge([l2_primary, l2_fusion, l2_upsample, l_total])
+
+        if self.compute_nme:
+            nme = tf.summary.scalar('nme', self.nme_loss)
+            self.batch_summary_op = tf.summary.merge([self.batch_summary_op, nme])
+
+        if self.log_histograms:
+            var_summary = [tf.summary.histogram(var.name, var) for var in tf.trainable_variables()]
+            grads = tf.gradients(self.total_loss, tf.trainable_variables())
+            grads = list(zip(grads, tf.trainable_variables()))
+            grad_summary = [tf.summary.histogram(var.name + '/grads', grad) for grad, var in grads]
+            activ_summary = [tf.summary.histogram(layer.name, layer) for layer in self.all_layers]
+            self.batch_summary_op = tf.summary.merge([self.batch_summary_op, var_summary, grad_summary, activ_summary])
+
+        if self.valid_size > 0 and self.compute_nme:
+            self.valid_summary = tf.summary.scalar('valid_nme', self.valid_nme_loss)
+
+        if self.sample_to_log:
+            img_map_summary_small = tf.summary.image('compare_map_to_gt_small', self.log_image_map_small)
+            img_map_summary = tf.summary.image('compare_map_to_gt', self.log_image_map)
+
+            if self.sample_per_channel:
+                map_channels_summary = tf.summary.image('compare_map_channels_to_gt', self.log_map_channels)
+                map_channels_summary_small = tf.summary.image('compare_map_channels_to_gt_small',
+                                                              self.log_map_channels_small)
+                self.img_summary = tf.summary.merge(
+                    [img_map_summary, img_map_summary_small, map_channels_summary, map_channels_summary_small])
+            else:
+                self.img_summary = tf.summary.merge([img_map_summary, img_map_summary_small])
+
+            if self.valid_size >= self.sample_grid:
+                img_map_summary_valid_small = tf.summary.image('compare_map_to_gt_small_valid', self.log_image_map_small)
+                img_map_summary_valid = tf.summary.image('compare_map_to_gt_valid', self.log_image_map)
+
+                if self.sample_per_channel:
+                    map_channels_summary_valid_small = tf.summary.image('compare_map_channels_to_gt_small_valid',
+                                                                        self.log_map_channels_small)
+                    map_channels_summary_valid = tf.summary.image('compare_map_channels_to_gt_valid',
+                                                                  self.log_map_channels)
+                    self.img_summary_valid = tf.summary.merge(
+                        [img_map_summary_valid, img_map_summary_valid_small, map_channels_summary_valid,
+                         map_channels_summary_valid_small])
+                else:
+                    self.img_summary_valid = tf.summary.merge([img_map_summary_valid, img_map_summary_valid_small])
+
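The `train()` method below drives the learning rate with `tf.train.exponential_decay(self.learning_rate, global_step, self.step, self.gamma, staircase=True)`. With `staircase=True` that call produces the step schedule sketched here (a plain-Python restatement; the attribute names are taken from the call itself, where `self.step` plays the role of `decay_steps` and `self.gamma` of `decay_rate`):

```python
def decayed_lr(base_lr, global_step, decay_steps, gamma):
    # staircase=True: the exponent only grows once every decay_steps steps
    return base_lr * gamma ** (global_step // decay_steps)

# e.g. base_lr=1e-4, gamma=0.1, decay every 30000 steps
for step in (0, 29999, 30000, 60000):
    print(step, decayed_lr(1e-4, step, 30000, 0.1))
```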
+    def train(self):
+        # set random seed
+        tf.set_random_seed(1234)
+        np.random.seed(1234)
+        # build a graph
+        # add placeholders
+        self.add_placeholders()
+        # build model
+        self.build_model()
+        # create loss ops
+        self.create_loss_ops()
+        # create summary ops
+        self.create_summary_ops()
+
+        # create optimizer and training op
+        global_step = tf.Variable(0, trainable=False)
+        lr = tf.train.exponential_decay(self.learning_rate, global_step, self.step, self.gamma, staircase=True)
+        if self.adam_optimizer:
+            optimizer = tf.train.AdamOptimizer(lr)
+        else:
+            optimizer = tf.train.MomentumOptimizer(lr, self.momentum)
+
+        train_op = optimizer.minimize(self.total_loss, global_step=global_step)
+
+        with tf.Session(config=self.config) as sess:
+
+            tf.global_variables_initializer().run()
+
+            # load pre-trained weights if load_pretrain is True
+            if self.load_pretrain:
+                print('\n*** loading pre-trained weights from: ' + self.pre_train_path + ' ***')
+                if self.load_primary_only:
+                    print('*** loading primary-net only ***')
+                    primary_var = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) if
+                                   ('deconv_' not in v.name) and ('_fsn_' not in v.name)]
+                    loader = tf.train.Saver(var_list=primary_var)
+                else:
+                    loader = tf.train.Saver()
+                loader.restore(sess, self.pre_train_path)
+                print("*** Model restore finished, current global step: %d" % global_step.eval())
+
+            # for fine-tuning, set reset_training_op=True; when resuming training, set reset_training_op=False
+            if self.reset_training_op:
+                print("resetting optimizer and global step")
+                opt_var_list = [optimizer.get_slot(var, name) for name in optimizer.get_slot_names()
+                                for var in tf.global_variables() if optimizer.get_slot(var, name) is not None]
+                opt_var_list_init = tf.variables_initializer(opt_var_list)
+                opt_var_list_init.run()
+                sess.run(global_step.initializer)
+
+            # create model saver and file writer
+            summary_writer = tf.summary.FileWriter(logdir=self.save_log_path, graph=tf.get_default_graph())
+            saver = tf.train.Saver()
+
+            print('\n*** Start Training ***')
+
+            # initialize some variables before the training loop
+            resume_step = global_step.eval()
+            num_train_images = len(self.img_menpo_list)
+            batches_in_epoch = int(float(num_train_images) / float(self.batch_size))
+            epoch = int(resume_step / batches_in_epoch)
+            img_inds = self.epoch_inds_shuffle[epoch, :]
+            log_valid = True
+            log_valid_images = True
+
+            # allocate space for batch images, maps and landmarks
+            batch_images = np.zeros([self.batch_size, self.image_size, self.image_size, self.c_dim]).astype('float32')
+            batch_lms = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
+            batch_lms_pred = np.zeros([self.batch_size, self.num_landmarks, 2]).astype('float32')
+
+            batch_maps_small = np.zeros((self.batch_size, int(self.image_size / 4),
+                                         int(self.image_size / 4), self.num_landmarks)).astype('float32')
+            batch_maps = np.zeros((self.batch_size, self.image_size, self.image_size,
+                                   self.num_landmarks)).astype('float32')
+
+            # create gaussians for heatmap generation
+            gaussian_filt_large = create_gaussian_filter(sigma=self.sigma, win_mult=self.win_mult)
+            gaussian_filt_small = create_gaussian_filter(sigma=1. * self.sigma / 4, win_mult=self.win_mult)
+
+            # training loop
+            for step in range(resume_step, self.train_iter):
+
+                j = step % batches_in_epoch  # j == 0 if we finished an epoch
+
+                # if we finished an epoch and this isn't the first step
+                if step > resume_step and j == 0:
+                    epoch += 1
+                    img_inds = self.epoch_inds_shuffle[epoch, :]  # get the next shuffled image inds
+                    log_valid = True
+                    log_valid_images = True
+                    if self.use_epoch_data:  # if using pre-augmented data, load the epoch directory
+                        epoch_dir = os.path.join(self.epoch_data_dir, str(epoch))
+                        self.img_menpo_list = load_menpo_image_list(
+                            self.img_path, train_crop_dir=epoch_dir, img_dir_ns=None, mode=self.mode,
+                            bb_dictionary=self.bb_dictionary, image_size=self.image_size, test_data=self.test_data,
+                            augment_basic=False, augment_texture=False, augment_geom=False)
+
+                # get batch indices
+                batch_inds = img_inds[j * self.batch_size:(j + 1) * self.batch_size]
+
+                # load batch images, gt maps and landmarks
+                load_images_landmarks_approx_maps_alloc_once(
+                    self.img_menpo_list, batch_inds, images=batch_images, maps_small=batch_maps_small,
+                    maps=batch_maps, landmarks=batch_lms, image_size=self.image_size,
+                    num_landmarks=self.num_landmarks, scale=self.scale, gauss_filt_large=gaussian_filt_large,
+                    gauss_filt_small=gaussian_filt_small, win_mult=self.win_mult, sigma=self.sigma,
+                    save_landmarks=self.compute_nme)
+
+                feed_dict_train = {self.images: batch_images, self.heatmaps: batch_maps,
+                                   self.heatmaps_small: batch_maps_small}
+
+                # train on batch
+                sess.run(train_op, feed_dict_train)
+
+                # save to log and print status
+                if step == resume_step or (step + 1) % self.print_every == 0:
+
+                    # train data log
+                    if self.compute_nme:
+                        batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
+
+                        batch_heat_maps_to_landmarks_alloc_once(
+                            batch_maps=batch_maps_pred, batch_landmarks=batch_lms_pred,
+                            batch_size=self.batch_size, image_size=self.image_size,
+                            num_landmarks=self.num_landmarks)
+
+                        train_feed_dict_log = {
+                            self.images: batch_images, self.heatmaps: batch_maps,
+                            self.heatmaps_small: batch_maps_small, self.train_lms: batch_lms,
+                            self.train_pred_lms: batch_lms_pred}
+
+                        summary, l_p, l_f, l_t, nme = sess.run(
+                            [self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss,
+                             self.nme_loss],
+                            train_feed_dict_log)
+
+                        print(
+                            'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f]'
+                            ' total loss: [%.6f] NME: [%.6f]' % (
+                                epoch, step + 1, self.train_iter, l_p, l_f, l_t, nme))
+                    else:
+                        train_feed_dict_log = {self.images: batch_images, self.heatmaps: batch_maps,
+                                               self.heatmaps_small: batch_maps_small}
+
+                        summary, l_p, l_f, l_t = sess.run(
+                            [self.batch_summary_op, self.l2_primary, self.l2_fusion, self.total_loss],
+                            train_feed_dict_log)
+                        print(
+                            'epoch: [%d] step: [%d/%d] primary loss: [%.6f] fusion loss: [%.6f] total loss: [%.6f]'
+                            % (epoch, step + 1, self.train_iter, l_p, l_f, l_t))
+
+                    summary_writer.add_summary(summary, step)
+
+                    # valid data log
+                    if self.valid_size > 0 and (log_valid and epoch % self.log_valid_every == 0) \
+                            and self.compute_nme:
+                        log_valid = False
+
+                        self.predict_valid_landmarks_in_batches(self.valid_images_loaded, sess)
+                        valid_feed_dict_log = {
+                            self.valid_lms: self.valid_landmarks_loaded,
+                            self.valid_pred_lms: self.valid_landmarks_pred}
+
+                        v_summary, v_nme = sess.run([self.valid_summary, self.valid_nme_loss],
+                                                    valid_feed_dict_log)
+                        summary_writer.add_summary(v_summary, step)
+                        print(
+                            'epoch: [%d] step: [%d/%d] valid NME: [%.6f]' % (
+                                epoch, step + 1, self.train_iter, v_nme))
+
+                # save model
+                if (step + 1) % self.save_every == 0:
+                    saver.save(sess, os.path.join(self.save_model_path, 'deep_heatmaps'), global_step=step + 1)
+                    print('model/deep-heatmaps-%d saved' % (step + 1))
+
+                # save images
+                if step == resume_step or (step + 1) % self.sample_every == 0:
+
+                    batch_maps_small_pred = sess.run(self.pred_hm_p, {self.images: batch_images})
+                    if not self.compute_nme:
+                        batch_maps_pred = sess.run(self.pred_hm_u, {self.images: batch_images})
+                        batch_lms_pred = None
+
+                    merged_img = merge_images_landmarks_maps_gt(
+                        batch_images.copy(), batch_maps_pred, batch_maps, landmarks=batch_lms_pred,
+                        image_size=self.image_size, num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
+                        scale=self.scale, circle_size=2, fast=self.fast_img_gen)
+
+                    merged_img_small = merge_images_landmarks_maps_gt(
+                        batch_images.copy(), batch_maps_small_pred, batch_maps_small,
+                        image_size=self.image_size,
+                        num_landmarks=self.num_landmarks, num_samples=self.sample_grid, scale=self.scale,
+                        circle_size=0, fast=self.fast_img_gen)
+
+                    if self.sample_per_channel:
+                        map_per_channel = map_comapre_channels(
+                            batch_images.copy(), batch_maps_pred, batch_maps, image_size=self.image_size,
+                            num_landmarks=self.num_landmarks, scale=self.scale)
+
+                        map_per_channel_small = map_comapre_channels(
+                            batch_images.copy(), batch_maps_small_pred, batch_maps_small,
+                            image_size=int(self.image_size / 4),
+                            num_landmarks=self.num_landmarks, scale=self.scale)
+
+                    if self.sample_to_log:  # save heatmap images to log
+                        if self.sample_per_channel:
+                            summary_img = sess.run(
+                                self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
+                                                   self.log_map_channels: np.expand_dims(map_per_channel, 0),
+                                                   self.log_image_map_small: np.expand_dims(merged_img_small, 0),
+                                                   self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
+                        else:
+                            summary_img = sess.run(
+                                self.img_summary, {self.log_image_map: np.expand_dims(merged_img, 0),
+                                                   self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
+                        summary_writer.add_summary(summary_img, step)
+
+                        if (self.valid_size >= self.sample_grid) and self.save_valid_images and \
+                                (log_valid_images and epoch % self.log_valid_every == 0):
+                            log_valid_images = False
+
+                            batch_maps_small_pred_val, batch_maps_pred_val = \
+                                sess.run([self.pred_hm_p, self.pred_hm_u],
+                                         {self.images: self.valid_images_loaded[:self.sample_grid]})
+
+                            merged_img_small = merge_images_landmarks_maps_gt(
+                                self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
+                                self.valid_gt_maps_small_loaded, image_size=self.image_size,
+                                num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
+                                scale=self.scale, circle_size=0, fast=self.fast_img_gen)
+
+                            merged_img = merge_images_landmarks_maps_gt(
+                                self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred_val,
+                                self.valid_gt_maps_loaded, image_size=self.image_size,
+                                num_landmarks=self.num_landmarks, num_samples=self.sample_grid,
+                                scale=self.scale, circle_size=2, fast=self.fast_img_gen)
+
+                            if self.sample_per_channel:
+                                map_per_channel_small = map_comapre_channels(
+                                    self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_small_pred_val,
+                                    self.valid_gt_maps_small_loaded, image_size=int(self.image_size / 4),
+                                    num_landmarks=self.num_landmarks, scale=self.scale)
+
+                                map_per_channel = map_comapre_channels(
+                                    self.valid_images_loaded[:self.sample_grid].copy(), batch_maps_pred_val,
+                                    self.valid_gt_maps_loaded, image_size=self.image_size,
+                                    num_landmarks=self.num_landmarks, scale=self.scale)
+
+                                summary_img = sess.run(
+                                    self.img_summary_valid,
+                                    {self.log_image_map: np.expand_dims(merged_img, 0),
+                                     self.log_map_channels: np.expand_dims(map_per_channel, 0),
+                                     self.log_image_map_small: np.expand_dims(merged_img_small, 0),
+                                     self.log_map_channels_small: np.expand_dims(map_per_channel_small, 0)})
+                            else:
+                                summary_img = sess.run(
+                                    self.img_summary_valid,
+                                    {self.log_image_map: np.expand_dims(merged_img, 0),
+                                     self.log_image_map_small: np.expand_dims(merged_img_small, 0)})
+
+                            summary_writer.add_summary(summary_img, step)
+                    else:  # save heatmap images to directory
+                        sample_path_imgs = os.path.join(
+                            self.save_sample_path, 'epoch-%d-train-iter-%d-1.png' % (epoch, step + 1))
+                        sample_path_imgs_small = os.path.join(
+                            self.save_sample_path, 'epoch-%d-train-iter-%d-1-s.png' % (epoch, step + 1))
+                        scipy.misc.imsave(sample_path_imgs, merged_img)
+                        scipy.misc.imsave(sample_path_imgs_small, merged_img_small)
+
+                        if self.sample_per_channel:
+                            sample_path_ch_maps = os.path.join(
+                                self.save_sample_path, 'epoch-%d-train-iter-%d-3.png' % (epoch, step + 1))
+                            sample_path_ch_maps_small = os.path.join(
+                                self.save_sample_path, 'epoch-%d-train-iter-%d-3-s.png' % (epoch, step + 1))
+                            scipy.misc.imsave(sample_path_ch_maps, map_per_channel)
+                            scipy.misc.imsave(sample_path_ch_maps_small, map_per_channel_small)
+
+            print('*** Finished Training ***')
+
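The ground-truth maps consumed by the training loop above come from `load_images_landmarks_approx_maps_alloc_once` together with `create_gaussian_filter`, both defined elsewhere in the repository. The general recipe for such targets is to stamp a Gaussian bump at each landmark location; the following is a self-contained sketch of that recipe under that assumption, not the repository's exact implementation:

```python
import numpy as np

def landmarks_to_gaussian_maps(landmarks, image_size, sigma=1.5):
    """landmarks: (N, 2) array of (x, y); returns (image_size, image_size, N) target maps."""
    xs = np.arange(image_size)
    grid_x, grid_y = np.meshgrid(xs, xs)  # grid_x varies along columns, grid_y along rows
    maps = np.zeros((image_size, image_size, len(landmarks)), dtype=np.float32)
    for i, (x, y) in enumerate(landmarks):
        maps[:, :, i] = np.exp(-((grid_x - x) ** 2 + (grid_y - y) ** 2) / (2. * sigma ** 2))
    return maps

lms = np.random.rand(68, 2) * 255
print(landmarks_to_gaussian_maps(lms, 256).shape)  # (256, 256, 68)
```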
+    def get_image_maps(self, test_image, reuse=None, norm=False):
+        """returns the heatmaps of an input image (menpo image object)"""
+
+        self.add_placeholders()
+        # build model
+        pred_hm_p, pred_hm_f, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse)
+
+        with tf.Session(config=self.config) as sess:
+            # load trained parameters
+            saver = tf.train.Saver()
+            saver.restore(sess, self.test_model_path)
+            _, model_name = os.path.split(self.test_model_path)
+
+            test_image = test_image.pixels_with_channels_at_back().astype('float32')
+            if norm:
+                if self.scale == '255':
+                    test_image *= 255
+                elif self.scale == '0':
+                    test_image = 2 * test_image - 1
+
+            map_primary, map_fusion, map_upsample = sess.run(
+                [pred_hm_p, pred_hm_f, pred_hm_u], {self.images: np.expand_dims(test_image, 0)})
+
+        return map_primary, map_fusion, map_upsample
+
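The `norm` branch above rescales a menpo image (pixels in [0, 1]) to whichever range the network was trained on, keyed by `self.scale`. A small standalone restatement of that mapping; treating any other `scale` value as "leave the image in [0, 1]" is an assumption, not something stated in this file:

```python
import numpy as np

def normalize_for_scale(image, scale):
    """image: float array in [0, 1] (menpo convention). Mirrors the `norm` branch above."""
    if scale == '255':
        return image * 255.     # network expects [0, 255] inputs
    elif scale == '0':
        return 2 * image - 1    # network expects [-1, 1] inputs
    return image                # assumed: keep [0, 1]

img = np.random.rand(256, 256, 3).astype('float32')
print(normalize_for_scale(img, '0').min() >= -1)
```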
+    def get_landmark_predictions(self, img_list, pdm_models_dir, clm_model_path, reuse=None, map_to_input_size=False):
+        """returns a dictionary with landmark predictions for each step of the ECpTp algorithm, as well as ECT"""
+
+        from thirdparty.face_of_art.pdm_clm_functions import feature_based_pdm_corr, clm_correct
+
+        jaw_line_inds = np.arange(0, 17)
+        left_brow_inds = np.arange(17, 22)
+        right_brow_inds = np.arange(22, 27)
+
+        self.add_placeholders()
+        # build model
+        _, _, pred_hm_u = self.heatmaps_network(self.images, reuse=reuse)
+
+        with tf.Session(config=self.config) as sess:
+            # load trained parameters
+            saver = tf.train.Saver()
+            saver.restore(sess, self.test_model_path)
+            _, model_name = os.path.split(self.test_model_path)
+            e_list = []
+            ect_list = []
+            ecp_list = []
+            ecpt_list = []
+            ecptp_jaw_list = []
+            ecptp_out_list = []
+
+            for test_image in img_list:
+
+                if map_to_input_size:
+                    test_image_transform = test_image[1]
+                    test_image = test_image[0]
+
+                # get landmarks for the estimation stage
+                if test_image.n_channels < 3:
+                    test_image_map = sess.run(
+                        pred_hm_u, {self.images: np.expand_dims(
+                            gray2rgb(test_image.pixels_with_channels_at_back()).astype('float32'), 0)})
+                else:
+                    test_image_map = sess.run(
+                        pred_hm_u, {self.images: np.expand_dims(
+                            test_image.pixels_with_channels_at_back().astype('float32'), 0)})
+                init_lms = heat_maps_to_landmarks(np.squeeze(test_image_map))
+
+                # get landmarks for the part-based correction stage
+                p_pdm_lms = feature_based_pdm_corr(lms_init=init_lms, models_dir=pdm_models_dir, train_type='basic')
+
+                # get landmarks for the part-based tuning stage
+                try:  # clm may not converge
+                    pdm_clm_lms = clm_correct(
+                        clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=p_pdm_lms)
+                except Exception:
+                    pdm_clm_lms = p_pdm_lms.copy()
+
+                # get ECT landmarks
+                try:  # clm may not converge
+                    ect_lms = clm_correct(
+                        clm_model_path=clm_model_path, image=test_image, map=test_image_map, lms_init=init_lms)
+                except Exception:
+                    ect_lms = p_pdm_lms.copy()
+
+                # get landmarks for ECpTp_out (tune jaw and eyebrows)
+                ecptp_out = p_pdm_lms.copy()
+                ecptp_out[left_brow_inds] = pdm_clm_lms[left_brow_inds]
+                ecptp_out[right_brow_inds] = pdm_clm_lms[right_brow_inds]
+                ecptp_out[jaw_line_inds] = pdm_clm_lms[jaw_line_inds]
+
+                # get landmarks for ECpTp_jaw (tune jaw only)
+                ecptp_jaw = p_pdm_lms.copy()
+                ecptp_jaw[jaw_line_inds] = pdm_clm_lms[jaw_line_inds]
+
+                if map_to_input_size:
+                    ecptp_jaw = test_image_transform.apply(ecptp_jaw)
+                    ecptp_out = test_image_transform.apply(ecptp_out)
+                    ect_lms = test_image_transform.apply(ect_lms)
+                    init_lms = test_image_transform.apply(init_lms)
+                    p_pdm_lms = test_image_transform.apply(p_pdm_lms)
+                    pdm_clm_lms = test_image_transform.apply(pdm_clm_lms)
+
+                ecptp_jaw_list.append(ecptp_jaw)  # E + p-correction + p-tuning (ECpTp_jaw)
+                ecptp_out_list.append(ecptp_out)  # E + p-correction + p-tuning (ECpTp_out)
+                ect_list.append(ect_lms)          # ECT prediction
+                e_list.append(init_lms)           # initial prediction from the heatmap network (E)
+                ecp_list.append(p_pdm_lms)        # initial prediction + part-based pdm correction (ECp)
+                ecpt_list.append(pdm_clm_lms)     # initial prediction + part-based pdm correction + global tuning (ECpT)
+
+            pred_dict = {
+                'E': e_list,
+                'ECp': ecp_list,
+                'ECpT': ecpt_list,
+                'ECT': ect_list,
+                'ECpTp_jaw': ecptp_jaw_list,
+                'ECpTp_out': ecptp_out_list
+            }
+
+        return pred_dict
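A likely calling pattern for `get_landmark_predictions`, shown only as a hedged sketch: `model`, `my_images`, and the two paths below are placeholders for an already-constructed heatmap-model object, a list of menpo images, and trained PDM/CLM model locations; none of them are values taken from this repository.

```python
# `model` is assumed to be a constructed heatmap-network object with test_model_path set.
pred_dict = model.get_landmark_predictions(
    img_list=my_images,             # list of menpo images, or (image, transform) pairs
                                    # when map_to_input_size=True
    pdm_models_dir='pdm_models/',   # hypothetical directory of part-based PDM models
    clm_model_path='clm_model',     # hypothetical CLM model path
    map_to_input_size=False)

# one landmark array per input image, for each stage of the pipeline
for stage in ('E', 'ECp', 'ECpT', 'ECT', 'ECpTp_jaw', 'ECpTp_out'):
    print(stage, len(pred_dict[stage]))
```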