fffiloni commited on
Commit
e394497
1 Parent(s): 674827a

Upload 25 files

Browse files
LICENSE ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tencent is pleased to support the open source community by making MimicMotion available.
2
+
3
+ Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
4
+
5
+ MimicMotion is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ --------------------------------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
36
+
37
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software Licensed under the Apache License Version 2.0:
73
+ The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2023 THL A29 Limited.
74
+ --------------------------------------------------------------------
75
+ 1. diffusers
76
+ Copyright (c) diffusers original author and authors
77
+
78
+ 2. DWPose
79
+ Copyright 2018-2020 Open-MMLab.
80
+ Please note this software has been modified by Tencent in this distribution.
81
+
82
+ 3. transformers
83
+ Copyright (c) transformers original author and authors
84
+
85
+ 4. decord
86
+ Copyright (c) DWPoseoriginal author and authors
87
+
88
+
89
+ A copy of Apache 2.0 has been included in this file.
90
+
91
+
92
+
93
+ Open Source Software Licensed under the BSD 3-Clause License:
94
+ --------------------------------------------------------------------
95
+ 1. torch
96
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
97
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
98
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
99
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
100
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
101
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
102
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
103
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
104
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
105
+
106
+ 2. omegaconf
107
+ Copyright (c) 2018, Omry Yadan
108
+ All rights reserved.
109
+
110
+ 3. torchvision
111
+ Copyright (c) Soumith Chintala 2016,
112
+ All rights reserved.
113
+
114
+
115
+ Terms of the BSD 3-Clause:
116
+ --------------------------------------------------------------------
117
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
118
+
119
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
120
+
121
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
122
+
123
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
124
+
125
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
126
+
127
+
128
+
129
+ Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
130
+ --------------------------------------------------------------------
131
+ 1. numpy
132
+ Copyright (c) 2005-2023, NumPy Developers.
133
+ All rights reserved.
134
+
135
+ A copy of the BSD 3-Clause is included in this file.
136
+
137
+ For the license of other third party components, please refer to the following URL:
138
+ https://github.com/numpy/numpy/blob/v1.26.3/LICENSES_bundled.txt
139
+
140
+
141
+
142
+ Open Source Software Licensed under the HPND License:
143
+ --------------------------------------------------------------------
144
+ 1. Pillow
145
+ Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors.
146
+
147
+
148
+ Terms of the HPND License:
149
+ --------------------------------------------------------------------
150
+ The Python Imaging Library (PIL) is
151
+
152
+ Copyright © 1997-2011 by Secret Labs AB
153
+ Copyright © 1995-2011 by Fredrik Lundh
154
+
155
+ Pillow is the friendly PIL fork. It is
156
+
157
+ Copyright © 2010-2023 by Jeffrey A. Clark (Alex) and contributors.
158
+
159
+ Like PIL, Pillow is licensed under the open source HPND License:
160
+
161
+ By obtaining, using, and/or copying this software and/or its associated
162
+ documentation, you agree that you have read, understood, and will comply
163
+ with the following terms and conditions:
164
+
165
+ Permission to use, copy, modify and distribute this software and its
166
+ documentation for any purpose and without fee is hereby granted,
167
+ provided that the above copyright notice appears in all copies, and that
168
+ both that copyright notice and this permission notice appear in supporting
169
+ documentation, and that the name of Secret Labs AB or the author not be
170
+ used in advertising or publicity pertaining to distribution of the software
171
+ without specific, written prior permission.
172
+
173
+ SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
174
+ SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS.
175
+ IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL,
176
+ INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
177
+ LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
178
+ OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
179
+ PERFORMANCE OF THIS SOFTWARE.
180
+
181
+
182
+ Open Source Software Licensed under the Matplotlib License and Other Licenses of the Third-Party Components therein:
183
+ --------------------------------------------------------------------
184
+ 1. matplotlib
185
+ Copyright (c)
186
+ 2012- Matplotlib Development Team; All Rights Reserved
187
+
188
+
189
+ Terms of the Matplotlib License:
190
+ --------------------------------------------------------------------
191
+ License agreement for matplotlib versions 1.3.0 and later
192
+ =========================================================
193
+
194
+ 1. This LICENSE AGREEMENT is between the Matplotlib Development Team
195
+ ("MDT"), and the Individual or Organization ("Licensee") accessing and
196
+ otherwise using matplotlib software in source or binary form and its
197
+ associated documentation.
198
+
199
+ 2. Subject to the terms and conditions of this License Agreement, MDT
200
+ hereby grants Licensee a nonexclusive, royalty-free, world-wide license
201
+ to reproduce, analyze, test, perform and/or display publicly, prepare
202
+ derivative works, distribute, and otherwise use matplotlib
203
+ alone or in any derivative version, provided, however, that MDT's
204
+ License Agreement and MDT's notice of copyright, i.e., "Copyright (c)
205
+ 2012- Matplotlib Development Team; All Rights Reserved" are retained in
206
+ matplotlib alone or in any derivative version prepared by
207
+ Licensee.
208
+
209
+ 3. In the event Licensee prepares a derivative work that is based on or
210
+ incorporates matplotlib or any part thereof, and wants to
211
+ make the derivative work available to others as provided herein, then
212
+ Licensee hereby agrees to include in any such work a brief summary of
213
+ the changes made to matplotlib .
214
+
215
+ 4. MDT is making matplotlib available to Licensee on an "AS
216
+ IS" basis. MDT MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
217
+ IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, MDT MAKES NO AND
218
+ DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
219
+ FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB
220
+ WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.
221
+
222
+ 5. MDT SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB
223
+ FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR
224
+ LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING
225
+ MATPLOTLIB , OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF
226
+ THE POSSIBILITY THEREOF.
227
+
228
+ 6. This License Agreement will automatically terminate upon a material
229
+ breach of its terms and conditions.
230
+
231
+ 7. Nothing in this License Agreement shall be deemed to create any
232
+ relationship of agency, partnership, or joint venture between MDT and
233
+ Licensee. This License Agreement does not grant permission to use MDT
234
+ trademarks or trade name in a trademark sense to endorse or promote
235
+ products or services of Licensee, or any third party.
236
+
237
+ 8. By copying, installing or otherwise using matplotlib ,
238
+ Licensee agrees to be bound by the terms and conditions of this License
239
+ Agreement.
240
+
241
+ License agreement for matplotlib versions prior to 1.3.0
242
+ ========================================================
243
+
244
+ 1. This LICENSE AGREEMENT is between John D. Hunter ("JDH"), and the
245
+ Individual or Organization ("Licensee") accessing and otherwise using
246
+ matplotlib software in source or binary form and its associated
247
+ documentation.
248
+
249
+ 2. Subject to the terms and conditions of this License Agreement, JDH
250
+ hereby grants Licensee a nonexclusive, royalty-free, world-wide license
251
+ to reproduce, analyze, test, perform and/or display publicly, prepare
252
+ derivative works, distribute, and otherwise use matplotlib
253
+ alone or in any derivative version, provided, however, that JDH's
254
+ License Agreement and JDH's notice of copyright, i.e., "Copyright (c)
255
+ 2002-2011 John D. Hunter; All Rights Reserved" are retained in
256
+ matplotlib alone or in any derivative version prepared by
257
+ Licensee.
258
+
259
+ 3. In the event Licensee prepares a derivative work that is based on or
260
+ incorporates matplotlib or any part thereof, and wants to
261
+ make the derivative work available to others as provided herein, then
262
+ Licensee hereby agrees to include in any such work a brief summary of
263
+ the changes made to matplotlib.
264
+
265
+ 4. JDH is making matplotlib available to Licensee on an "AS
266
+ IS" basis. JDH MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
267
+ IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, JDH MAKES NO AND
268
+ DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
269
+ FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF MATPLOTLIB
270
+ WILL NOT INFRINGE ANY THIRD PARTY RIGHTS.
271
+
272
+ 5. JDH SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF MATPLOTLIB
273
+ FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR
274
+ LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING
275
+ MATPLOTLIB , OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF
276
+ THE POSSIBILITY THEREOF.
277
+
278
+ 6. This License Agreement will automatically terminate upon a material
279
+ breach of its terms and conditions.
280
+
281
+ 7. Nothing in this License Agreement shall be deemed to create any
282
+ relationship of agency, partnership, or joint venture between JDH and
283
+ Licensee. This License Agreement does not grant permission to use JDH
284
+ trademarks or trade name in a trademark sense to endorse or promote
285
+ products or services of Licensee, or any third party.
286
+
287
+ 8. By copying, installing or otherwise using matplotlib,
288
+ Licensee agrees to be bound by the terms and conditions of this License
289
+ Agreement.
290
+
291
+ For the license of other third party components, please refer to the following URL:
292
+ https://github.com/matplotlib/matplotlib/tree/v3.8.0/LICENSE
293
+
294
+
295
+ Open Source Software Licensed under the MIT License:
296
+ --------------------------------------------------------------------
297
+ 1. einops
298
+ Copyright (c) 2018 Alex Rogozhnikov
299
+
300
+ 2. onnxruntime
301
+ Copyright (c) Microsoft Corporation
302
+
303
+ 3. OpenCV
304
+ Copyright (c) Olli-Pekka Heinisuo
305
+
306
+
307
+ Terms of the MIT License:
308
+ --------------------------------------------------------------------
309
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
310
+
311
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
312
+
313
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
cog.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for Cog ⚙️
2
+ # Reference: https://cog.run/yaml
3
+
4
+ build:
5
+ # set to true if your model requires a GPU
6
+ gpu: true
7
+ # cuda: "11.7"
8
+
9
+ # a list of ubuntu apt packages to install
10
+ system_packages:
11
+ - "libgl1-mesa-glx"
12
+ - "libglib2.0-0"
13
+
14
+ # python version in the form '3.11' or '3.11.4'
15
+ python_version: "3.11"
16
+
17
+ # a list of packages in the format <package-name>==<version>
18
+ python_packages:
19
+ - "torch>=2.3" # 2.3.1
20
+ - "torchvision>=0.18" # 0.18.1
21
+ - "diffusers>=0.29" # 0.29.2
22
+ - "transformers>=4.42" # 4.42.3
23
+ - "decord>=0.6" # 0.6.0
24
+ - "einops>=0.8" # 0.8.0
25
+ - "omegaconf>=2.3" # 2.3.0
26
+ - "opencv-python>=4.10" # 4.10.0.84
27
+ - "matplotlib>=3.9" # 3.9.1
28
+ - "onnxruntime>=1.18" # 1.18.1
29
+ - "accelerate>=0.32" # 0.32.0
30
+ - "av>=12.2" # 12.2.0, https://github.com/continue-revolution/sd-webui-animatediff/issues/377
31
+
32
+ # commands run after the environment is setup
33
+ run:
34
+ - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
35
+
36
+ # predict.py defines how predictions are run on your model
37
+ predict: "predict.py:Predictor"
configs/test.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # base svd model path
2
+ base_model_path: stabilityai/stable-video-diffusion-img2vid-xt-1-1
3
+
4
+ # checkpoint path
5
+ ckpt_path: models/MimicMotion_1-1.pth
6
+
7
+ test_case:
8
+ - ref_video_path: assets/example_data/videos/pose1.mp4
9
+ ref_image_path: assets/example_data/images/demo1.jpg
10
+ num_frames: 72
11
+ resolution: 576
12
+ frames_overlap: 6
13
+ num_inference_steps: 25
14
+ noise_aug_strength: 0
15
+ guidance_scale: 2.0
16
+ sample_stride: 2
17
+ fps: 15
18
+ seed: 42
19
+
20
+
constants.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # w/h apsect ratio
2
+ ASPECT_RATIO = 9 / 16
environment.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mimicmotion
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.11
7
+ - pytorch=2.0.1
8
+ - torchvision=0.15.2
9
+ - pytorch-cuda=11.7
10
+ - pip
11
+ - pip:
12
+ - diffusers==0.27.0
13
+ - transformers==4.32.1
14
+ - decord==0.6.0
15
+ - einops
16
+ - omegaconf
inference.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import logging
4
+ import math
5
+ from omegaconf import OmegaConf
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+
9
+ import numpy as np
10
+ import torch.jit
11
+ from torchvision.datasets.folder import pil_loader
12
+ from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
13
+ from torchvision.transforms.functional import to_pil_image
14
+
15
+
16
+ from mimicmotion.utils.geglu_patch import patch_geglu_inplace
17
+ patch_geglu_inplace()
18
+
19
+ from constants import ASPECT_RATIO
20
+
21
+ from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
22
+ from mimicmotion.utils.loader import create_pipeline
23
+ from mimicmotion.utils.utils import save_to_mp4
24
+ from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
25
+
26
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
27
+ logger = logging.getLogger(__name__)
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+
31
+ def preprocess(video_path, image_path, resolution=576, sample_stride=2):
32
+ """preprocess ref image pose and video pose
33
+
34
+ Args:
35
+ video_path (str): input video pose path
36
+ image_path (str): reference image path
37
+ resolution (int, optional): Defaults to 576.
38
+ sample_stride (int, optional): Defaults to 2.
39
+ """
40
+ image_pixels = pil_loader(image_path)
41
+ image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
42
+ h, w = image_pixels.shape[-2:]
43
+ ############################ compute target h/w according to original aspect ratio ###############################
44
+ if h>w:
45
+ w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
46
+ else:
47
+ w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
48
+ h_w_ratio = float(h) / float(w)
49
+ if h_w_ratio < h_target / w_target:
50
+ h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
51
+ else:
52
+ h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
53
+ image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
54
+ image_pixels = center_crop(image_pixels, [h_target, w_target])
55
+ image_pixels = image_pixels.permute((1, 2, 0)).numpy()
56
+ ##################################### get image&video pose value #################################################
57
+ image_pose = get_image_pose(image_pixels)
58
+ video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)
59
+ pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
60
+ image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
61
+ return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1
62
+
63
+
64
+ def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
65
+ image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
66
+ generator = torch.Generator(device=device)
67
+ generator.manual_seed(task_config.seed)
68
+ frames = pipeline(
69
+ image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(0),
70
+ tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
71
+ height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
72
+ noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
73
+ generator=generator, min_guidance_scale=task_config.guidance_scale,
74
+ max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device
75
+ ).frames.cpu()
76
+ video_frames = (frames * 255.0).to(torch.uint8)
77
+
78
+ for vid_idx in range(video_frames.shape[0]):
79
+ # deprecated first frame because of ref image
80
+ _video_frames = video_frames[vid_idx, 1:]
81
+
82
+ return _video_frames
83
+
84
+
85
+ @torch.no_grad()
86
+ def main(args):
87
+ if not args.no_use_float16 :
88
+ torch.set_default_dtype(torch.float16)
89
+
90
+ infer_config = OmegaConf.load(args.inference_config)
91
+ pipeline = create_pipeline(infer_config, device)
92
+
93
+ for task in infer_config.test_case:
94
+ ############################################## Pre-process data ##############################################
95
+ pose_pixels, image_pixels = preprocess(
96
+ task.ref_video_path, task.ref_image_path,
97
+ resolution=task.resolution, sample_stride=task.sample_stride
98
+ )
99
+ ########################################### Run MimicMotion pipeline ###########################################
100
+ _video_frames = run_pipeline(
101
+ pipeline,
102
+ image_pixels, pose_pixels,
103
+ device, task
104
+ )
105
+ ################################### save results to output folder. ###########################################
106
+ save_to_mp4(
107
+ _video_frames,
108
+ f"{args.output_dir}/{os.path.basename(task.ref_video_path).split('.')[0]}" \
109
+ f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4",
110
+ fps=task.fps,
111
+ )
112
+
113
+ def set_logger(log_file=None, log_level=logging.INFO):
114
+ log_handler = logging.FileHandler(log_file, "w")
115
+ log_handler.setFormatter(
116
+ logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
117
+ )
118
+ log_handler.setLevel(log_level)
119
+ logger.addHandler(log_handler)
120
+
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--log_file", type=str, default=None)
125
+ parser.add_argument("--inference_config", type=str, default="configs/test.yaml") #ToDo
126
+ parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
127
+ parser.add_argument("--no_use_float16",
128
+ action="store_true",
129
+ help="Whether use float16 to speed up inference",
130
+ )
131
+ args = parser.parse_args()
132
+
133
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
134
+ set_logger(args.log_file \
135
+ if args.log_file is not None else f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log")
136
+ main(args)
137
+ logger.info(f"--- Finished ---")
138
+
mimicmotion/__init__.py ADDED
File without changes
mimicmotion/dwpose/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
mimicmotion/dwpose/__init__.py ADDED
File without changes
mimicmotion/dwpose/dwpose_detector.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .wholebody import Wholebody
7
+
8
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ class DWposeDetector:
12
+ """
13
+ A pose detect method for image-like data.
14
+
15
+ Parameters:
16
+ model_det: (str) serialized ONNX format model path,
17
+ such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
18
+ model_pose: (str) serialized ONNX format model path,
19
+ such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
20
+ device: (str) 'cpu' or 'cuda:{device_id}'
21
+ """
22
+ def __init__(self, model_det, model_pose, device='cpu'):
23
+ self.args = model_det, model_pose, device
24
+
25
+ def release_memory(self):
26
+ if hasattr(self, 'pose_estimation'):
27
+ del self.pose_estimation
28
+ import gc; gc.collect()
29
+
30
+ def __call__(self, oriImg):
31
+ if not hasattr(self, 'pose_estimation'):
32
+ self.pose_estimation = Wholebody(*self.args)
33
+
34
+ oriImg = oriImg.copy()
35
+ H, W, C = oriImg.shape
36
+ with torch.no_grad():
37
+ candidate, score = self.pose_estimation(oriImg)
38
+ nums, _, locs = candidate.shape
39
+ candidate[..., 0] /= float(W)
40
+ candidate[..., 1] /= float(H)
41
+ body = candidate[:, :18].copy()
42
+ body = body.reshape(nums * 18, locs)
43
+ subset = score[:, :18].copy()
44
+ for i in range(len(subset)):
45
+ for j in range(len(subset[i])):
46
+ if subset[i][j] > 0.3:
47
+ subset[i][j] = int(18 * i + j)
48
+ else:
49
+ subset[i][j] = -1
50
+
51
+ # un_visible = subset < 0.3
52
+ # candidate[un_visible] = -1
53
+
54
+ # foot = candidate[:, 18:24]
55
+
56
+ faces = candidate[:, 24:92]
57
+
58
+ hands = candidate[:, 92:113]
59
+ hands = np.vstack([hands, candidate[:, 113:]])
60
+
61
+ faces_score = score[:, 24:92]
62
+ hands_score = np.vstack([score[:, 92:113], score[:, 113:]])
63
+
64
+ bodies = dict(candidate=body, subset=subset, score=score[:, :18])
65
+ pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)
66
+
67
+ return pose
68
+
69
+ dwpose_detector = DWposeDetector(
70
+ model_det="models/DWPose/yolox_l.onnx",
71
+ model_pose="models/DWPose/dw-ll_ucoco_384.onnx",
72
+ device=device)
mimicmotion/dwpose/onnxdet.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ def nms(boxes, scores, nms_thr):
6
+ """Single class NMS implemented in Numpy.
7
+
8
+ Args:
9
+ boxes (np.ndarray): shape=(N,4); N is number of boxes
10
+ scores (np.ndarray): the score of bboxes
11
+ nms_thr (float): the threshold in NMS
12
+
13
+ Returns:
14
+ List[int]: output bbox ids
15
+ """
16
+ x1 = boxes[:, 0]
17
+ y1 = boxes[:, 1]
18
+ x2 = boxes[:, 2]
19
+ y2 = boxes[:, 3]
20
+
21
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
22
+ order = scores.argsort()[::-1]
23
+
24
+ keep = []
25
+ while order.size > 0:
26
+ i = order[0]
27
+ keep.append(i)
28
+ xx1 = np.maximum(x1[i], x1[order[1:]])
29
+ yy1 = np.maximum(y1[i], y1[order[1:]])
30
+ xx2 = np.minimum(x2[i], x2[order[1:]])
31
+ yy2 = np.minimum(y2[i], y2[order[1:]])
32
+
33
+ w = np.maximum(0.0, xx2 - xx1 + 1)
34
+ h = np.maximum(0.0, yy2 - yy1 + 1)
35
+ inter = w * h
36
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
37
+
38
+ inds = np.where(ovr <= nms_thr)[0]
39
+ order = order[inds + 1]
40
+
41
+ return keep
42
+
43
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
44
+ """Multiclass NMS implemented in Numpy. Class-aware version.
45
+
46
+ Args:
47
+ boxes (np.ndarray): shape=(N,4); N is number of boxes
48
+ scores (np.ndarray): the score of bboxes
49
+ nms_thr (float): the threshold in NMS
50
+ score_thr (float): the threshold of cls score
51
+
52
+ Returns:
53
+ np.ndarray: outputs bboxes coordinate
54
+ """
55
+ final_dets = []
56
+ num_classes = scores.shape[1]
57
+ for cls_ind in range(num_classes):
58
+ cls_scores = scores[:, cls_ind]
59
+ valid_score_mask = cls_scores > score_thr
60
+ if valid_score_mask.sum() == 0:
61
+ continue
62
+ else:
63
+ valid_scores = cls_scores[valid_score_mask]
64
+ valid_boxes = boxes[valid_score_mask]
65
+ keep = nms(valid_boxes, valid_scores, nms_thr)
66
+ if len(keep) > 0:
67
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
68
+ dets = np.concatenate(
69
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
70
+ )
71
+ final_dets.append(dets)
72
+ if len(final_dets) == 0:
73
+ return None
74
+ return np.concatenate(final_dets, 0)
75
+
76
+ def demo_postprocess(outputs, img_size, p6=False):
77
+ grids = []
78
+ expanded_strides = []
79
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
80
+
81
+ hsizes = [img_size[0] // stride for stride in strides]
82
+ wsizes = [img_size[1] // stride for stride in strides]
83
+
84
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
85
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
86
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
87
+ grids.append(grid)
88
+ shape = grid.shape[:2]
89
+ expanded_strides.append(np.full((*shape, 1), stride))
90
+
91
+ grids = np.concatenate(grids, 1)
92
+ expanded_strides = np.concatenate(expanded_strides, 1)
93
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
94
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
95
+
96
+ return outputs
97
+
98
+ def preprocess(img, input_size, swap=(2, 0, 1)):
99
+ if len(img.shape) == 3:
100
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
101
+ else:
102
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
103
+
104
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
105
+ resized_img = cv2.resize(
106
+ img,
107
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
108
+ interpolation=cv2.INTER_LINEAR,
109
+ ).astype(np.uint8)
110
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
111
+
112
+ padded_img = padded_img.transpose(swap)
113
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
114
+ return padded_img, r
115
+
116
+ def inference_detector(session, oriImg):
117
+ """run human detect
118
+ """
119
+ input_shape = (640,640)
120
+ img, ratio = preprocess(oriImg, input_shape)
121
+
122
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
123
+ output = session.run(None, ort_inputs)
124
+ predictions = demo_postprocess(output[0], input_shape)[0]
125
+
126
+ boxes = predictions[:, :4]
127
+ scores = predictions[:, 4:5] * predictions[:, 5:]
128
+
129
+ boxes_xyxy = np.ones_like(boxes)
130
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
131
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
132
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
133
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
134
+ boxes_xyxy /= ratio
135
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
136
+ if dets is not None:
137
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
138
+ isscore = final_scores>0.3
139
+ iscat = final_cls_inds == 0
140
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
141
+ final_boxes = final_boxes[isbbox]
142
+ else:
143
+ final_boxes = np.array([])
144
+
145
+ return final_boxes
mimicmotion/dwpose/onnxpose.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+
7
+ def preprocess(
8
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
9
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
10
+ """Do preprocessing for RTMPose model inference.
11
+
12
+ Args:
13
+ img (np.ndarray): Input image in shape.
14
+ input_size (tuple): Input image size in shape (w, h).
15
+
16
+ Returns:
17
+ tuple:
18
+ - resized_img (np.ndarray): Preprocessed image.
19
+ - center (np.ndarray): Center of image.
20
+ - scale (np.ndarray): Scale of image.
21
+ """
22
+ # get shape of image
23
+ img_shape = img.shape[:2]
24
+ out_img, out_center, out_scale = [], [], []
25
+ if len(out_bbox) == 0:
26
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
27
+ for i in range(len(out_bbox)):
28
+ x0 = out_bbox[i][0]
29
+ y0 = out_bbox[i][1]
30
+ x1 = out_bbox[i][2]
31
+ y1 = out_bbox[i][3]
32
+ bbox = np.array([x0, y0, x1, y1])
33
+
34
+ # get center and scale
35
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
36
+
37
+ # do affine transformation
38
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
39
+
40
+ # normalize image
41
+ mean = np.array([123.675, 116.28, 103.53])
42
+ std = np.array([58.395, 57.12, 57.375])
43
+ resized_img = (resized_img - mean) / std
44
+
45
+ out_img.append(resized_img)
46
+ out_center.append(center)
47
+ out_scale.append(scale)
48
+
49
+ return out_img, out_center, out_scale
50
+
51
+
52
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
53
+ """Inference RTMPose model.
54
+
55
+ Args:
56
+ sess (ort.InferenceSession): ONNXRuntime session.
57
+ img (np.ndarray): Input image in shape.
58
+
59
+ Returns:
60
+ outputs (np.ndarray): Output of RTMPose model.
61
+ """
62
+ all_out = []
63
+ # build input
64
+ for i in range(len(img)):
65
+ input = [img[i].transpose(2, 0, 1)]
66
+
67
+ # build output
68
+ sess_input = {sess.get_inputs()[0].name: input}
69
+ sess_output = []
70
+ for out in sess.get_outputs():
71
+ sess_output.append(out.name)
72
+
73
+ # run model
74
+ outputs = sess.run(sess_output, sess_input)
75
+ all_out.append(outputs)
76
+
77
+ return all_out
78
+
79
+
80
+ def postprocess(outputs: List[np.ndarray],
81
+ model_input_size: Tuple[int, int],
82
+ center: Tuple[int, int],
83
+ scale: Tuple[int, int],
84
+ simcc_split_ratio: float = 2.0
85
+ ) -> Tuple[np.ndarray, np.ndarray]:
86
+ """Postprocess for RTMPose model output.
87
+
88
+ Args:
89
+ outputs (np.ndarray): Output of RTMPose model.
90
+ model_input_size (tuple): RTMPose model Input image size.
91
+ center (tuple): Center of bbox in shape (x, y).
92
+ scale (tuple): Scale of bbox in shape (w, h).
93
+ simcc_split_ratio (float): Split ratio of simcc.
94
+
95
+ Returns:
96
+ tuple:
97
+ - keypoints (np.ndarray): Rescaled keypoints.
98
+ - scores (np.ndarray): Model predict scores.
99
+ """
100
+ all_key = []
101
+ all_score = []
102
+ for i in range(len(outputs)):
103
+ # use simcc to decode
104
+ simcc_x, simcc_y = outputs[i]
105
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
106
+
107
+ # rescale keypoints
108
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
109
+ all_key.append(keypoints[0])
110
+ all_score.append(scores[0])
111
+
112
+ return np.array(all_key), np.array(all_score)
113
+
114
+
115
+ def bbox_xyxy2cs(bbox: np.ndarray,
116
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
117
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
118
+
119
+ Args:
120
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
121
+ as (left, top, right, bottom)
122
+ padding (float): BBox padding factor that will be multilied to scale.
123
+ Default: 1.0
124
+
125
+ Returns:
126
+ tuple: A tuple containing center and scale.
127
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
128
+ (n, 2)
129
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
130
+ (n, 2)
131
+ """
132
+ # convert single bbox from (4, ) to (1, 4)
133
+ dim = bbox.ndim
134
+ if dim == 1:
135
+ bbox = bbox[None, :]
136
+
137
+ # get bbox center and scale
138
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
139
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
140
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
141
+
142
+ if dim == 1:
143
+ center = center[0]
144
+ scale = scale[0]
145
+
146
+ return center, scale
147
+
148
+
149
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
150
+ aspect_ratio: float) -> np.ndarray:
151
+ """Extend the scale to match the given aspect ratio.
152
+
153
+ Args:
154
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
155
+ aspect_ratio (float): The ratio of ``w/h``
156
+
157
+ Returns:
158
+ np.ndarray: The reshaped image scale in (2, )
159
+ """
160
+ w, h = np.hsplit(bbox_scale, [1])
161
+ bbox_scale = np.where(w > h * aspect_ratio,
162
+ np.hstack([w, w / aspect_ratio]),
163
+ np.hstack([h * aspect_ratio, h]))
164
+ return bbox_scale
165
+
166
+
167
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
168
+ """Rotate a point by an angle.
169
+
170
+ Args:
171
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
172
+ angle_rad (float): rotation angle in radian
173
+
174
+ Returns:
175
+ np.ndarray: Rotated point in shape (2, )
176
+ """
177
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
178
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
179
+ return rot_mat @ pt
180
+
181
+
182
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
183
+ """To calculate the affine matrix, three pairs of points are required. This
184
+ function is used to get the 3rd point, given 2D points a & b.
185
+
186
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
187
+ anticlockwise, using b as the rotation center.
188
+
189
+ Args:
190
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
191
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
192
+
193
+ Returns:
194
+ np.ndarray: The 3rd point.
195
+ """
196
+ direction = a - b
197
+ c = b + np.r_[-direction[1], direction[0]]
198
+ return c
199
+
200
+
201
+ def get_warp_matrix(center: np.ndarray,
202
+ scale: np.ndarray,
203
+ rot: float,
204
+ output_size: Tuple[int, int],
205
+ shift: Tuple[float, float] = (0., 0.),
206
+ inv: bool = False) -> np.ndarray:
207
+ """Calculate the affine transformation matrix that can warp the bbox area
208
+ in the input image to the output size.
209
+
210
+ Args:
211
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
212
+ scale (np.ndarray[2, ]): Scale of the bounding box
213
+ wrt [width, height].
214
+ rot (float): Rotation angle (degree).
215
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
216
+ destination heatmaps.
217
+ shift (0-100%): Shift translation ratio wrt the width/height.
218
+ Default (0., 0.).
219
+ inv (bool): Option to inverse the affine transform direction.
220
+ (inv=False: src->dst or inv=True: dst->src)
221
+
222
+ Returns:
223
+ np.ndarray: A 2x3 transformation matrix
224
+ """
225
+ shift = np.array(shift)
226
+ src_w = scale[0]
227
+ dst_w = output_size[0]
228
+ dst_h = output_size[1]
229
+
230
+ # compute transformation matrix
231
+ rot_rad = np.deg2rad(rot)
232
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
233
+ dst_dir = np.array([0., dst_w * -0.5])
234
+
235
+ # get four corners of the src rectangle in the original image
236
+ src = np.zeros((3, 2), dtype=np.float32)
237
+ src[0, :] = center + scale * shift
238
+ src[1, :] = center + src_dir + scale * shift
239
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
240
+
241
+ # get four corners of the dst rectangle in the input image
242
+ dst = np.zeros((3, 2), dtype=np.float32)
243
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
244
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
245
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
246
+
247
+ if inv:
248
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
249
+ else:
250
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
251
+
252
+ return warp_mat
253
+
254
+
255
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
256
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
257
+ """Get the bbox image as the model input by affine transform.
258
+
259
+ Args:
260
+ input_size (dict): The input size of the model.
261
+ bbox_scale (dict): The bbox scale of the img.
262
+ bbox_center (dict): The bbox center of the img.
263
+ img (np.ndarray): The original image.
264
+
265
+ Returns:
266
+ tuple: A tuple containing center and scale.
267
+ - np.ndarray[float32]: img after affine transform.
268
+ - np.ndarray[float32]: bbox scale after affine transform.
269
+ """
270
+ w, h = input_size
271
+ warp_size = (int(w), int(h))
272
+
273
+ # reshape bbox to fixed aspect ratio
274
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
275
+
276
+ # get the affine matrix
277
+ center = bbox_center
278
+ scale = bbox_scale
279
+ rot = 0
280
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
281
+
282
+ # do affine transform
283
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
284
+
285
+ return img, bbox_scale
286
+
287
+
288
+ def get_simcc_maximum(simcc_x: np.ndarray,
289
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
290
+ """Get maximum response location and value from simcc representations.
291
+
292
+ Note:
293
+ instance number: N
294
+ num_keypoints: K
295
+ heatmap height: H
296
+ heatmap width: W
297
+
298
+ Args:
299
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
300
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
301
+
302
+ Returns:
303
+ tuple:
304
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
305
+ (K, 2) or (N, K, 2)
306
+ - vals (np.ndarray): values of maximum heatmap responses in shape
307
+ (K,) or (N, K)
308
+ """
309
+ N, K, Wx = simcc_x.shape
310
+ simcc_x = simcc_x.reshape(N * K, -1)
311
+ simcc_y = simcc_y.reshape(N * K, -1)
312
+
313
+ # get maximum value locations
314
+ x_locs = np.argmax(simcc_x, axis=1)
315
+ y_locs = np.argmax(simcc_y, axis=1)
316
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
317
+ max_val_x = np.amax(simcc_x, axis=1)
318
+ max_val_y = np.amax(simcc_y, axis=1)
319
+
320
+ # get maximum value across x and y axis
321
+ mask = max_val_x > max_val_y
322
+ max_val_x[mask] = max_val_y[mask]
323
+ vals = max_val_x
324
+ locs[vals <= 0.] = -1
325
+
326
+ # reshape
327
+ locs = locs.reshape(N, K, 2)
328
+ vals = vals.reshape(N, K)
329
+
330
+ return locs, vals
331
+
332
+
333
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
334
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
335
+ """Modulate simcc distribution with Gaussian.
336
+
337
+ Args:
338
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
339
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
340
+ simcc_split_ratio (int): The split ratio of simcc.
341
+
342
+ Returns:
343
+ tuple: A tuple containing center and scale.
344
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
345
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
346
+ """
347
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
348
+ keypoints /= simcc_split_ratio
349
+
350
+ return keypoints, scores
351
+
352
+
353
+ def inference_pose(session, out_bbox, oriImg):
354
+ """run pose detect
355
+
356
+ Args:
357
+ session (ort.InferenceSession): ONNXRuntime session.
358
+ out_bbox (np.ndarray): bbox list
359
+ oriImg (np.ndarray): Input image in shape.
360
+
361
+ Returns:
362
+ tuple:
363
+ - keypoints (np.ndarray): Rescaled keypoints.
364
+ - scores (np.ndarray): Model predict scores.
365
+ """
366
+ h, w = session.get_inputs()[0].shape[2:]
367
+ model_input_size = (w, h)
368
+ # preprocess for rtm-pose model inference.
369
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
370
+ # run pose estimation for processed img
371
+ outputs = inference(session, resized_img)
372
+ # postprocess for rtm-pose model output.
373
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
374
+
375
+ return keypoints, scores
mimicmotion/dwpose/preprocess.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import decord
3
+ import numpy as np
4
+
5
+ from .util import draw_pose
6
+ from .dwpose_detector import dwpose_detector as dwprocessor
7
+
8
+
9
+ def get_video_pose(
10
+ video_path: str,
11
+ ref_image: np.ndarray,
12
+ sample_stride: int=1):
13
+ """preprocess ref image pose and video pose
14
+
15
+ Args:
16
+ video_path (str): video pose path
17
+ ref_image (np.ndarray): reference image
18
+ sample_stride (int, optional): Defaults to 1.
19
+
20
+ Returns:
21
+ np.ndarray: sequence of video pose
22
+ """
23
+ # select ref-keypoint from reference pose for pose rescale
24
+ ref_pose = dwprocessor(ref_image)
25
+ ref_keypoint_id = [0, 1, 2, 5, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
26
+ ref_keypoint_id = [i for i in ref_keypoint_id \
27
+ if len(ref_pose['bodies']['subset']) > 0 and ref_pose['bodies']['subset'][0][i] >= .0]
28
+ ref_body = ref_pose['bodies']['candidate'][ref_keypoint_id]
29
+
30
+ height, width, _ = ref_image.shape
31
+
32
+ # read input video
33
+ vr = decord.VideoReader(video_path, ctx=decord.cpu(0))
34
+ sample_stride *= max(1, int(vr.get_avg_fps() / 24))
35
+
36
+ frames = vr.get_batch(list(range(0, len(vr), sample_stride))).asnumpy()
37
+ detected_poses = [dwprocessor(frm) for frm in tqdm(frames, desc="DWPose")]
38
+ dwprocessor.release_memory()
39
+
40
+ detected_bodies = np.stack(
41
+ [p['bodies']['candidate'] for p in detected_poses if p['bodies']['candidate'].shape[0] == 18])[:,
42
+ ref_keypoint_id]
43
+ # compute linear-rescale params
44
+ ay, by = np.polyfit(detected_bodies[:, :, 1].flatten(), np.tile(ref_body[:, 1], len(detected_bodies)), 1)
45
+ fh, fw, _ = vr[0].shape
46
+ ax = ay / (fh / fw / height * width)
47
+ bx = np.mean(np.tile(ref_body[:, 0], len(detected_bodies)) - detected_bodies[:, :, 0].flatten() * ax)
48
+ a = np.array([ax, ay])
49
+ b = np.array([bx, by])
50
+ output_pose = []
51
+ # pose rescale
52
+ for detected_pose in detected_poses:
53
+ detected_pose['bodies']['candidate'] = detected_pose['bodies']['candidate'] * a + b
54
+ detected_pose['faces'] = detected_pose['faces'] * a + b
55
+ detected_pose['hands'] = detected_pose['hands'] * a + b
56
+ im = draw_pose(detected_pose, height, width)
57
+ output_pose.append(np.array(im))
58
+ return np.stack(output_pose)
59
+
60
+
61
+ def get_image_pose(ref_image):
62
+ """process image pose
63
+
64
+ Args:
65
+ ref_image (np.ndarray): reference image pixel value
66
+
67
+ Returns:
68
+ np.ndarray: pose visual image in RGB-mode
69
+ """
70
+ height, width, _ = ref_image.shape
71
+ ref_pose = dwprocessor(ref_image)
72
+ pose_img = draw_pose(ref_pose, height, width)
73
+ return np.array(pose_img)
mimicmotion/dwpose/util.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import cv2
5
+
6
+
7
+ eps = 0.01
8
+
9
+ def alpha_blend_color(color, alpha):
10
+ """blend color according to point conf
11
+ """
12
+ return [int(c * alpha) for c in color]
13
+
14
+ def draw_bodypose(canvas, candidate, subset, score):
15
+ H, W, C = canvas.shape
16
+ candidate = np.array(candidate)
17
+ subset = np.array(subset)
18
+
19
+ stickwidth = 4
20
+
21
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
22
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
23
+ [1, 16], [16, 18], [3, 17], [6, 18]]
24
+
25
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
26
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
27
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
28
+
29
+ for i in range(17):
30
+ for n in range(len(subset)):
31
+ index = subset[n][np.array(limbSeq[i]) - 1]
32
+ conf = score[n][np.array(limbSeq[i]) - 1]
33
+ if conf[0] < 0.3 or conf[1] < 0.3:
34
+ continue
35
+ Y = candidate[index.astype(int), 0] * float(W)
36
+ X = candidate[index.astype(int), 1] * float(H)
37
+ mX = np.mean(X)
38
+ mY = np.mean(Y)
39
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
40
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
41
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
42
+ cv2.fillConvexPoly(canvas, polygon, alpha_blend_color(colors[i], conf[0] * conf[1]))
43
+
44
+ canvas = (canvas * 0.6).astype(np.uint8)
45
+
46
+ for i in range(18):
47
+ for n in range(len(subset)):
48
+ index = int(subset[n][i])
49
+ if index == -1:
50
+ continue
51
+ x, y = candidate[index][0:2]
52
+ conf = score[n][i]
53
+ x = int(x * W)
54
+ y = int(y * H)
55
+ cv2.circle(canvas, (int(x), int(y)), 4, alpha_blend_color(colors[i], conf), thickness=-1)
56
+
57
+ return canvas
58
+
59
+ def draw_handpose(canvas, all_hand_peaks, all_hand_scores):
60
+ H, W, C = canvas.shape
61
+
62
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
63
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
64
+
65
+ for peaks, scores in zip(all_hand_peaks, all_hand_scores):
66
+
67
+ for ie, e in enumerate(edges):
68
+ x1, y1 = peaks[e[0]]
69
+ x2, y2 = peaks[e[1]]
70
+ x1 = int(x1 * W)
71
+ y1 = int(y1 * H)
72
+ x2 = int(x2 * W)
73
+ y2 = int(y2 * H)
74
+ score = int(scores[e[0]] * scores[e[1]] * 255)
75
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
76
+ cv2.line(canvas, (x1, y1), (x2, y2),
77
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * score, thickness=2)
78
+
79
+ for i, keyponit in enumerate(peaks):
80
+ x, y = keyponit
81
+ x = int(x * W)
82
+ y = int(y * H)
83
+ score = int(scores[i] * 255)
84
+ if x > eps and y > eps:
85
+ cv2.circle(canvas, (x, y), 4, (0, 0, score), thickness=-1)
86
+ return canvas
87
+
88
+ def draw_facepose(canvas, all_lmks, all_scores):
89
+ H, W, C = canvas.shape
90
+ for lmks, scores in zip(all_lmks, all_scores):
91
+ for lmk, score in zip(lmks, scores):
92
+ x, y = lmk
93
+ x = int(x * W)
94
+ y = int(y * H)
95
+ conf = int(score * 255)
96
+ if x > eps and y > eps:
97
+ cv2.circle(canvas, (x, y), 3, (conf, conf, conf), thickness=-1)
98
+ return canvas
99
+
100
+ def draw_pose(pose, H, W, ref_w=2160):
101
+ """vis dwpose outputs
102
+
103
+ Args:
104
+ pose (List): DWposeDetector outputs in dwpose_detector.py
105
+ H (int): height
106
+ W (int): width
107
+ ref_w (int, optional) Defaults to 2160.
108
+
109
+ Returns:
110
+ np.ndarray: image pixel value in RGB mode
111
+ """
112
+ bodies = pose['bodies']
113
+ faces = pose['faces']
114
+ hands = pose['hands']
115
+ candidate = bodies['candidate']
116
+ subset = bodies['subset']
117
+
118
+ sz = min(H, W)
119
+ sr = (ref_w / sz) if sz != ref_w else 1
120
+
121
+ ########################################## create zero canvas ##################################################
122
+ canvas = np.zeros(shape=(int(H*sr), int(W*sr), 3), dtype=np.uint8)
123
+
124
+ ########################################### draw body pose #####################################################
125
+ canvas = draw_bodypose(canvas, candidate, subset, score=bodies['score'])
126
+
127
+ ########################################### draw hand pose #####################################################
128
+ canvas = draw_handpose(canvas, hands, pose['hands_score'])
129
+
130
+ ########################################### draw face pose #####################################################
131
+ canvas = draw_facepose(canvas, faces, pose['faces_score'])
132
+
133
+ return cv2.cvtColor(cv2.resize(canvas, (W, H)), cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
mimicmotion/dwpose/wholebody.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+
4
+ from .onnxdet import inference_detector
5
+ from .onnxpose import inference_pose
6
+
7
+
8
+ class Wholebody:
9
+ """detect human pose by dwpose
10
+ """
11
+ def __init__(self, model_det, model_pose, device="cpu"):
12
+ providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
13
+ provider_options = None if device == 'cpu' else [{'device_id': 0}]
14
+
15
+ self.session_det = ort.InferenceSession(
16
+ path_or_bytes=model_det, providers=providers, provider_options=provider_options
17
+ )
18
+ self.session_pose = ort.InferenceSession(
19
+ path_or_bytes=model_pose, providers=providers, provider_options=provider_options
20
+ )
21
+
22
+ def __call__(self, oriImg):
23
+ """call to process dwpose-detect
24
+
25
+ Args:
26
+ oriImg (np.ndarray): detected image
27
+
28
+ """
29
+ det_result = inference_detector(self.session_det, oriImg)
30
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
31
+
32
+ keypoints_info = np.concatenate(
33
+ (keypoints, scores[..., None]), axis=-1)
34
+ # compute neck joint
35
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36
+ # neck score when visualizing pred
37
+ neck[:, 2:4] = np.logical_and(
38
+ keypoints_info[:, 5, 2:4] > 0.3,
39
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
40
+ new_keypoints_info = np.insert(
41
+ keypoints_info, 17, neck, axis=1)
42
+ mmpose_idx = [
43
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
44
+ ]
45
+ openpose_idx = [
46
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
47
+ ]
48
+ new_keypoints_info[:, openpose_idx] = \
49
+ new_keypoints_info[:, mmpose_idx]
50
+ keypoints_info = new_keypoints_info
51
+
52
+ keypoints, scores = keypoints_info[
53
+ ..., :2], keypoints_info[..., 2]
54
+
55
+ return keypoints, scores
56
+
57
+
mimicmotion/modules/__init__.py ADDED
File without changes
mimicmotion/modules/attention.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
7
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.resnet import AlphaBlender
10
+ from diffusers.utils import BaseOutput
11
+ from torch import nn
12
+
13
+
14
+ @dataclass
15
+ class TransformerTemporalModelOutput(BaseOutput):
16
+ """
17
+ The output of [`TransformerTemporalModel`].
18
+
19
+ Args:
20
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
21
+ The hidden states output conditioned on `encoder_hidden_states` input.
22
+ """
23
+
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
28
+ """
29
+ A Transformer model for video-like data.
30
+
31
+ Parameters:
32
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
33
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
34
+ in_channels (`int`, *optional*):
35
+ The number of channels in the input and output (specify if the input is **continuous**).
36
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
37
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
38
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
39
+ attention_bias (`bool`, *optional*):
40
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
41
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
42
+ This is fixed during training since it is used to learn a number of position embeddings.
43
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
44
+ Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
45
+ activation functions.
46
+ norm_elementwise_affine (`bool`, *optional*):
47
+ Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
48
+ double_self_attention (`bool`, *optional*):
49
+ Configure if each `TransformerBlock` should contain two self-attention layers.
50
+ positional_embeddings: (`str`, *optional*):
51
+ The type of positional embeddings to apply to the sequence input before passing use.
52
+ num_positional_embeddings: (`int`, *optional*):
53
+ The maximum length of the sequence over which to apply positional embeddings.
54
+ """
55
+
56
+ @register_to_config
57
+ def __init__(
58
+ self,
59
+ num_attention_heads: int = 16,
60
+ attention_head_dim: int = 88,
61
+ in_channels: Optional[int] = None,
62
+ out_channels: Optional[int] = None,
63
+ num_layers: int = 1,
64
+ dropout: float = 0.0,
65
+ norm_num_groups: int = 32,
66
+ cross_attention_dim: Optional[int] = None,
67
+ attention_bias: bool = False,
68
+ sample_size: Optional[int] = None,
69
+ activation_fn: str = "geglu",
70
+ norm_elementwise_affine: bool = True,
71
+ double_self_attention: bool = True,
72
+ positional_embeddings: Optional[str] = None,
73
+ num_positional_embeddings: Optional[int] = None,
74
+ ):
75
+ super().__init__()
76
+ self.num_attention_heads = num_attention_heads
77
+ self.attention_head_dim = attention_head_dim
78
+ inner_dim = num_attention_heads * attention_head_dim
79
+
80
+ self.in_channels = in_channels
81
+
82
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
83
+ self.proj_in = nn.Linear(in_channels, inner_dim)
84
+
85
+ # 3. Define transformers blocks
86
+ self.transformer_blocks = nn.ModuleList(
87
+ [
88
+ BasicTransformerBlock(
89
+ inner_dim,
90
+ num_attention_heads,
91
+ attention_head_dim,
92
+ dropout=dropout,
93
+ cross_attention_dim=cross_attention_dim,
94
+ activation_fn=activation_fn,
95
+ attention_bias=attention_bias,
96
+ double_self_attention=double_self_attention,
97
+ norm_elementwise_affine=norm_elementwise_affine,
98
+ positional_embeddings=positional_embeddings,
99
+ num_positional_embeddings=num_positional_embeddings,
100
+ )
101
+ for d in range(num_layers)
102
+ ]
103
+ )
104
+
105
+ self.proj_out = nn.Linear(inner_dim, in_channels)
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states: torch.FloatTensor,
110
+ encoder_hidden_states: Optional[torch.LongTensor] = None,
111
+ timestep: Optional[torch.LongTensor] = None,
112
+ class_labels: torch.LongTensor = None,
113
+ num_frames: int = 1,
114
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
115
+ return_dict: bool = True,
116
+ ) -> TransformerTemporalModelOutput:
117
+ """
118
+ The [`TransformerTemporal`] forward method.
119
+
120
+ Args:
121
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete,
122
+ `torch.FloatTensor` of shape `(batch size, channel, height, width)`if continuous): Input hidden_states.
123
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
124
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
125
+ self-attention.
126
+ timestep ( `torch.LongTensor`, *optional*):
127
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
128
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
129
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
130
+ `AdaLayerZeroNorm`.
131
+ num_frames (`int`, *optional*, defaults to 1):
132
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
133
+ cross_attention_kwargs (`dict`, *optional*):
134
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
135
+ `self.processor` in [diffusers.models.attention_processor](
136
+ https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
137
+ return_dict (`bool`, *optional*, defaults to `True`):
138
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
139
+ tuple.
140
+
141
+ Returns:
142
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
143
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
144
+ returned, otherwise a `tuple` where the first element is the sample tensor.
145
+ """
146
+ # 1. Input
147
+ batch_frames, channel, height, width = hidden_states.shape
148
+ batch_size = batch_frames // num_frames
149
+
150
+ residual = hidden_states
151
+
152
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
153
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
154
+
155
+ hidden_states = self.norm(hidden_states)
156
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
157
+
158
+ hidden_states = self.proj_in(hidden_states)
159
+
160
+ # 2. Blocks
161
+ for block in self.transformer_blocks:
162
+ hidden_states = block(
163
+ hidden_states,
164
+ encoder_hidden_states=encoder_hidden_states,
165
+ timestep=timestep,
166
+ cross_attention_kwargs=cross_attention_kwargs,
167
+ class_labels=class_labels,
168
+ )
169
+
170
+ # 3. Output
171
+ hidden_states = self.proj_out(hidden_states)
172
+ hidden_states = (
173
+ hidden_states[None, None, :]
174
+ .reshape(batch_size, height, width, num_frames, channel)
175
+ .permute(0, 3, 4, 1, 2)
176
+ .contiguous()
177
+ )
178
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
179
+
180
+ output = hidden_states + residual
181
+
182
+ if not return_dict:
183
+ return (output,)
184
+
185
+ return TransformerTemporalModelOutput(sample=output)
186
+
187
+
188
+ class TransformerSpatioTemporalModel(nn.Module):
189
+ """
190
+ A Transformer model for video-like data.
191
+
192
+ Parameters:
193
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
194
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
195
+ in_channels (`int`, *optional*):
196
+ The number of channels in the input and output (specify if the input is **continuous**).
197
+ out_channels (`int`, *optional*):
198
+ The number of channels in the output (specify if the input is **continuous**).
199
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
200
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
201
+ """
202
+
203
+ def __init__(
204
+ self,
205
+ num_attention_heads: int = 16,
206
+ attention_head_dim: int = 88,
207
+ in_channels: int = 320,
208
+ out_channels: Optional[int] = None,
209
+ num_layers: int = 1,
210
+ cross_attention_dim: Optional[int] = None,
211
+ ):
212
+ super().__init__()
213
+ self.num_attention_heads = num_attention_heads
214
+ self.attention_head_dim = attention_head_dim
215
+
216
+ inner_dim = num_attention_heads * attention_head_dim
217
+ self.inner_dim = inner_dim
218
+
219
+ # 2. Define input layers
220
+ self.in_channels = in_channels
221
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
222
+ self.proj_in = nn.Linear(in_channels, inner_dim)
223
+
224
+ # 3. Define transformers blocks
225
+ self.transformer_blocks = nn.ModuleList(
226
+ [
227
+ BasicTransformerBlock(
228
+ inner_dim,
229
+ num_attention_heads,
230
+ attention_head_dim,
231
+ cross_attention_dim=cross_attention_dim,
232
+ )
233
+ for d in range(num_layers)
234
+ ]
235
+ )
236
+
237
+ time_mix_inner_dim = inner_dim
238
+ self.temporal_transformer_blocks = nn.ModuleList(
239
+ [
240
+ TemporalBasicTransformerBlock(
241
+ inner_dim,
242
+ time_mix_inner_dim,
243
+ num_attention_heads,
244
+ attention_head_dim,
245
+ cross_attention_dim=cross_attention_dim,
246
+ )
247
+ for _ in range(num_layers)
248
+ ]
249
+ )
250
+
251
+ time_embed_dim = in_channels * 4
252
+ self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
253
+ self.time_proj = Timesteps(in_channels, True, 0)
254
+ self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
255
+
256
+ # 4. Define output layers
257
+ self.out_channels = in_channels if out_channels is None else out_channels
258
+ # TODO: should use out_channels for continuous projections
259
+ self.proj_out = nn.Linear(inner_dim, in_channels)
260
+
261
+ self.gradient_checkpointing = False
262
+
263
+ def forward(
264
+ self,
265
+ hidden_states: torch.Tensor,
266
+ encoder_hidden_states: Optional[torch.Tensor] = None,
267
+ image_only_indicator: Optional[torch.Tensor] = None,
268
+ return_dict: bool = True,
269
+ ):
270
+ """
271
+ Args:
272
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
273
+ Input hidden_states.
274
+ num_frames (`int`):
275
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
276
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
277
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
278
+ self-attention.
279
+ image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
280
+ A tensor indicating whether the input contains only images. 1 indicates that the input contains only
281
+ images, 0 indicates that the input contains video frames.
282
+ return_dict (`bool`, *optional*, defaults to `True`):
283
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`]
284
+ instead of a plain tuple.
285
+
286
+ Returns:
287
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
288
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
289
+ returned, otherwise a `tuple` where the first element is the sample tensor.
290
+ """
291
+ # 1. Input
292
+ batch_frames, _, height, width = hidden_states.shape
293
+ num_frames = image_only_indicator.shape[-1]
294
+ batch_size = batch_frames // num_frames
295
+
296
+ time_context = encoder_hidden_states
297
+ time_context_first_timestep = time_context[None, :].reshape(
298
+ batch_size, num_frames, -1, time_context.shape[-1]
299
+ )[:, 0]
300
+ time_context = time_context_first_timestep[None, :].broadcast_to(
301
+ height * width, batch_size, 1, time_context.shape[-1]
302
+ )
303
+ time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
304
+
305
+ residual = hidden_states
306
+
307
+ hidden_states = self.norm(hidden_states)
308
+ inner_dim = hidden_states.shape[1]
309
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
310
+ hidden_states = torch.utils.checkpoint.checkpoint(self.proj_in, hidden_states)
311
+
312
+ num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
313
+ num_frames_emb = num_frames_emb.repeat(batch_size, 1)
314
+ num_frames_emb = num_frames_emb.reshape(-1)
315
+ t_emb = self.time_proj(num_frames_emb)
316
+
317
+ # `Timesteps` does not contain any weights and will always return f32 tensors
318
+ # but time_embedding might actually be running in fp16. so we need to cast here.
319
+ # there might be better ways to encapsulate this.
320
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
321
+
322
+ emb = self.time_pos_embed(t_emb)
323
+ emb = emb[:, None, :]
324
+
325
+ # 2. Blocks
326
+ for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
327
+ if self.gradient_checkpointing:
328
+ hidden_states = torch.utils.checkpoint.checkpoint(
329
+ block,
330
+ hidden_states,
331
+ None,
332
+ encoder_hidden_states,
333
+ None,
334
+ use_reentrant=False,
335
+ )
336
+ else:
337
+ hidden_states = block(
338
+ hidden_states,
339
+ encoder_hidden_states=encoder_hidden_states,
340
+ )
341
+
342
+ hidden_states_mix = hidden_states
343
+ hidden_states_mix = hidden_states_mix + emb
344
+
345
+ if self.gradient_checkpointing:
346
+ hidden_states_mix = torch.utils.checkpoint.checkpoint(
347
+ temporal_block,
348
+ hidden_states_mix,
349
+ num_frames,
350
+ time_context,
351
+ )
352
+ hidden_states = self.time_mixer(
353
+ x_spatial=hidden_states,
354
+ x_temporal=hidden_states_mix,
355
+ image_only_indicator=image_only_indicator,
356
+ )
357
+ else:
358
+ hidden_states_mix = temporal_block(
359
+ hidden_states_mix,
360
+ num_frames=num_frames,
361
+ encoder_hidden_states=time_context,
362
+ )
363
+ hidden_states = self.time_mixer(
364
+ x_spatial=hidden_states,
365
+ x_temporal=hidden_states_mix,
366
+ image_only_indicator=image_only_indicator,
367
+ )
368
+
369
+ # 3. Output
370
+ hidden_states = torch.utils.checkpoint.checkpoint(self.proj_out, hidden_states)
371
+ hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
372
+
373
+ output = hidden_states + residual
374
+
375
+ if not return_dict:
376
+ return (output,)
377
+
378
+ return TransformerTemporalModelOutput(sample=output)
mimicmotion/modules/pose_net.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import einops
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+
9
+
10
+ class PoseNet(nn.Module):
11
+ """a tiny conv network for introducing pose sequence as the condition
12
+ """
13
+ def __init__(self, noise_latent_channels=320, *args, **kwargs):
14
+ super().__init__(*args, **kwargs)
15
+ # multiple convolution layers
16
+ self.conv_layers = nn.Sequential(
17
+ nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1),
18
+ nn.SiLU(),
19
+ nn.Conv2d(in_channels=3, out_channels=16, kernel_size=4, stride=2, padding=1),
20
+ nn.SiLU(),
21
+
22
+ nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1),
23
+ nn.SiLU(),
24
+ nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2, padding=1),
25
+ nn.SiLU(),
26
+
27
+ nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
28
+ nn.SiLU(),
29
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1),
30
+ nn.SiLU(),
31
+
32
+ nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
33
+ nn.SiLU(),
34
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
35
+ nn.SiLU()
36
+ )
37
+
38
+ # Final projection layer
39
+ self.final_proj = nn.Conv2d(in_channels=128, out_channels=noise_latent_channels, kernel_size=1)
40
+
41
+ # Initialize layers
42
+ self._initialize_weights()
43
+
44
+ self.scale = nn.Parameter(torch.ones(1) * 2)
45
+
46
+ def _initialize_weights(self):
47
+ """Initialize weights with He. initialization and zero out the biases
48
+ """
49
+ for m in self.conv_layers:
50
+ if isinstance(m, nn.Conv2d):
51
+ n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
52
+ init.normal_(m.weight, mean=0.0, std=np.sqrt(2. / n))
53
+ if m.bias is not None:
54
+ init.zeros_(m.bias)
55
+ init.zeros_(self.final_proj.weight)
56
+ if self.final_proj.bias is not None:
57
+ init.zeros_(self.final_proj.bias)
58
+
59
+ def forward(self, x):
60
+ if x.ndim == 5:
61
+ x = einops.rearrange(x, "b f c h w -> (b f) c h w")
62
+ x = self.conv_layers(x)
63
+ x = self.final_proj(x)
64
+
65
+ return x * self.scale
66
+
67
+ @classmethod
68
+ def from_pretrained(cls, pretrained_model_path):
69
+ """load pretrained pose-net weights
70
+ """
71
+ if not Path(pretrained_model_path).exists():
72
+ print(f"There is no model file in {pretrained_model_path}")
73
+ print(f"loaded PoseNet's pretrained weights from {pretrained_model_path}.")
74
+
75
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
76
+ model = PoseNet(noise_latent_channels=320)
77
+
78
+ model.load_state_dict(state_dict, strict=True)
79
+
80
+ return model
mimicmotion/modules/unet.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.loaders import UNet2DConditionLoadersMixin
8
+ from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
9
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.utils import BaseOutput, logging
12
+
13
+ from diffusers.models.unets.unet_3d_blocks import get_down_block, get_up_block, UNetMidBlockSpatioTemporal
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ @dataclass
19
+ class UNetSpatioTemporalConditionOutput(BaseOutput):
20
+ """
21
+ The output of [`UNetSpatioTemporalConditionModel`].
22
+
23
+ Args:
24
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
25
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
26
+ """
27
+
28
+ sample: torch.FloatTensor = None
29
+
30
+
31
+ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
32
+ r"""
33
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state,
34
+ and a timestep and returns a sample shaped output.
35
+
36
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
37
+ for all models (such as downloading or saving).
38
+
39
+ Parameters:
40
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
41
+ Height and width of input/output sample.
42
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
43
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
44
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal",
45
+ "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
46
+ The tuple of downsample blocks to use.
47
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal",
48
+ "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
49
+ The tuple of upsample blocks to use.
50
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
51
+ The tuple of output channels for each block.
52
+ addition_time_embed_dim: (`int`, defaults to 256):
53
+ Dimension to to encode the additional time ids.
54
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
55
+ The dimension of the projection of encoded `added_time_ids`.
56
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
57
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
58
+ The dimension of the cross attention features.
59
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
60
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
61
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
62
+ [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
63
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
64
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
65
+ The number of attention heads.
66
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
67
+ """
68
+
69
+ _supports_gradient_checkpointing = True
70
+
71
+ @register_to_config
72
+ def __init__(
73
+ self,
74
+ sample_size: Optional[int] = None,
75
+ in_channels: int = 8,
76
+ out_channels: int = 4,
77
+ down_block_types: Tuple[str] = (
78
+ "CrossAttnDownBlockSpatioTemporal",
79
+ "CrossAttnDownBlockSpatioTemporal",
80
+ "CrossAttnDownBlockSpatioTemporal",
81
+ "DownBlockSpatioTemporal",
82
+ ),
83
+ up_block_types: Tuple[str] = (
84
+ "UpBlockSpatioTemporal",
85
+ "CrossAttnUpBlockSpatioTemporal",
86
+ "CrossAttnUpBlockSpatioTemporal",
87
+ "CrossAttnUpBlockSpatioTemporal",
88
+ ),
89
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
90
+ addition_time_embed_dim: int = 256,
91
+ projection_class_embeddings_input_dim: int = 768,
92
+ layers_per_block: Union[int, Tuple[int]] = 2,
93
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
94
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
95
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
96
+ num_frames: int = 25,
97
+ ):
98
+ super().__init__()
99
+
100
+ self.sample_size = sample_size
101
+
102
+ # Check inputs
103
+ if len(down_block_types) != len(up_block_types):
104
+ raise ValueError(
105
+ f"Must provide the same number of `down_block_types` as `up_block_types`. " \
106
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
107
+ )
108
+
109
+ if len(block_out_channels) != len(down_block_types):
110
+ raise ValueError(
111
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. " \
112
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
113
+ )
114
+
115
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
116
+ raise ValueError(
117
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. " \
118
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
119
+ )
120
+
121
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
122
+ raise ValueError(
123
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " \
124
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
125
+ )
126
+
127
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
128
+ raise ValueError(
129
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. " \
130
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
131
+ )
132
+
133
+ # input
134
+ self.conv_in = nn.Conv2d(
135
+ in_channels,
136
+ block_out_channels[0],
137
+ kernel_size=3,
138
+ padding=1,
139
+ )
140
+
141
+ # time
142
+ time_embed_dim = block_out_channels[0] * 4
143
+
144
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
145
+ timestep_input_dim = block_out_channels[0]
146
+
147
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
148
+
149
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
150
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
151
+
152
+ self.down_blocks = nn.ModuleList([])
153
+ self.up_blocks = nn.ModuleList([])
154
+
155
+ if isinstance(num_attention_heads, int):
156
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
157
+
158
+ if isinstance(cross_attention_dim, int):
159
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
160
+
161
+ if isinstance(layers_per_block, int):
162
+ layers_per_block = [layers_per_block] * len(down_block_types)
163
+
164
+ if isinstance(transformer_layers_per_block, int):
165
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
166
+
167
+ blocks_time_embed_dim = time_embed_dim
168
+
169
+ # down
170
+ output_channel = block_out_channels[0]
171
+ for i, down_block_type in enumerate(down_block_types):
172
+ input_channel = output_channel
173
+ output_channel = block_out_channels[i]
174
+ is_final_block = i == len(block_out_channels) - 1
175
+
176
+ down_block = get_down_block(
177
+ down_block_type,
178
+ num_layers=layers_per_block[i],
179
+ transformer_layers_per_block=transformer_layers_per_block[i],
180
+ in_channels=input_channel,
181
+ out_channels=output_channel,
182
+ temb_channels=blocks_time_embed_dim,
183
+ add_downsample=not is_final_block,
184
+ resnet_eps=1e-5,
185
+ cross_attention_dim=cross_attention_dim[i],
186
+ num_attention_heads=num_attention_heads[i],
187
+ resnet_act_fn="silu",
188
+ )
189
+ self.down_blocks.append(down_block)
190
+
191
+ # mid
192
+ self.mid_block = UNetMidBlockSpatioTemporal(
193
+ block_out_channels[-1],
194
+ temb_channels=blocks_time_embed_dim,
195
+ transformer_layers_per_block=transformer_layers_per_block[-1],
196
+ cross_attention_dim=cross_attention_dim[-1],
197
+ num_attention_heads=num_attention_heads[-1],
198
+ )
199
+
200
+ # count how many layers upsample the images
201
+ self.num_upsamplers = 0
202
+
203
+ # up
204
+ reversed_block_out_channels = list(reversed(block_out_channels))
205
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
206
+ reversed_layers_per_block = list(reversed(layers_per_block))
207
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
208
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
209
+
210
+ output_channel = reversed_block_out_channels[0]
211
+ for i, up_block_type in enumerate(up_block_types):
212
+ is_final_block = i == len(block_out_channels) - 1
213
+
214
+ prev_output_channel = output_channel
215
+ output_channel = reversed_block_out_channels[i]
216
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
217
+
218
+ # add upsample block for all BUT final layer
219
+ if not is_final_block:
220
+ add_upsample = True
221
+ self.num_upsamplers += 1
222
+ else:
223
+ add_upsample = False
224
+
225
+ up_block = get_up_block(
226
+ up_block_type,
227
+ num_layers=reversed_layers_per_block[i] + 1,
228
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
229
+ in_channels=input_channel,
230
+ out_channels=output_channel,
231
+ prev_output_channel=prev_output_channel,
232
+ temb_channels=blocks_time_embed_dim,
233
+ add_upsample=add_upsample,
234
+ resnet_eps=1e-5,
235
+ resolution_idx=i,
236
+ cross_attention_dim=reversed_cross_attention_dim[i],
237
+ num_attention_heads=reversed_num_attention_heads[i],
238
+ resnet_act_fn="silu",
239
+ )
240
+ self.up_blocks.append(up_block)
241
+ prev_output_channel = output_channel
242
+
243
+ # out
244
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
245
+ self.conv_act = nn.SiLU()
246
+
247
+ self.conv_out = nn.Conv2d(
248
+ block_out_channels[0],
249
+ out_channels,
250
+ kernel_size=3,
251
+ padding=1,
252
+ )
253
+
254
+ @property
255
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
256
+ r"""
257
+ Returns:
258
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
259
+ indexed by its weight name.
260
+ """
261
+ # set recursively
262
+ processors = {}
263
+
264
+ def fn_recursive_add_processors(
265
+ name: str,
266
+ module: torch.nn.Module,
267
+ processors: Dict[str, AttentionProcessor],
268
+ ):
269
+ if hasattr(module, "get_processor"):
270
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
271
+
272
+ for sub_name, child in module.named_children():
273
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274
+
275
+ return processors
276
+
277
+ for name, module in self.named_children():
278
+ fn_recursive_add_processors(name, module, processors)
279
+
280
+ return processors
281
+
282
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
283
+ r"""
284
+ Sets the attention processor to use to compute attention.
285
+
286
+ Parameters:
287
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
288
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
289
+ for **all** `Attention` layers.
290
+
291
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
292
+ processor. This is strongly recommended when setting trainable attention processors.
293
+
294
+ """
295
+ count = len(self.attn_processors.keys())
296
+
297
+ if isinstance(processor, dict) and len(processor) != count:
298
+ raise ValueError(
299
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
300
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
301
+ )
302
+
303
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
304
+ if hasattr(module, "set_processor"):
305
+ if not isinstance(processor, dict):
306
+ module.set_processor(processor)
307
+ else:
308
+ module.set_processor(processor.pop(f"{name}.processor"))
309
+
310
+ for sub_name, child in module.named_children():
311
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
312
+
313
+ for name, module in self.named_children():
314
+ fn_recursive_attn_processor(name, module, processor)
315
+
316
+ def set_default_attn_processor(self):
317
+ """
318
+ Disables custom attention processors and sets the default attention implementation.
319
+ """
320
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
321
+ processor = AttnProcessor()
322
+ else:
323
+ raise ValueError(
324
+ f"Cannot call `set_default_attn_processor` " \
325
+ f"when attention processors are of type {next(iter(self.attn_processors.values()))}"
326
+ )
327
+
328
+ self.set_attn_processor(processor)
329
+
330
+ def _set_gradient_checkpointing(self, module, value=False):
331
+ if hasattr(module, "gradient_checkpointing"):
332
+ module.gradient_checkpointing = value
333
+
334
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
335
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
336
+ """
337
+ Sets the attention processor to use [feed forward
338
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
339
+
340
+ Parameters:
341
+ chunk_size (`int`, *optional*):
342
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
343
+ over each tensor of dim=`dim`.
344
+ dim (`int`, *optional*, defaults to `0`):
345
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
346
+ or dim=1 (sequence length).
347
+ """
348
+ if dim not in [0, 1]:
349
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
350
+
351
+ # By default chunk size is 1
352
+ chunk_size = chunk_size or 1
353
+
354
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
355
+ if hasattr(module, "set_chunk_feed_forward"):
356
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
357
+
358
+ for child in module.children():
359
+ fn_recursive_feed_forward(child, chunk_size, dim)
360
+
361
+ for module in self.children():
362
+ fn_recursive_feed_forward(module, chunk_size, dim)
363
+
364
+ def forward(
365
+ self,
366
+ sample: torch.FloatTensor,
367
+ timestep: Union[torch.Tensor, float, int],
368
+ encoder_hidden_states: torch.Tensor,
369
+ added_time_ids: torch.Tensor,
370
+ pose_latents: torch.Tensor = None,
371
+ image_only_indicator: bool = False,
372
+ return_dict: bool = True,
373
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
374
+ r"""
375
+ The [`UNetSpatioTemporalConditionModel`] forward method.
376
+
377
+ Args:
378
+ sample (`torch.FloatTensor`):
379
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
380
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
381
+ encoder_hidden_states (`torch.FloatTensor`):
382
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
383
+ added_time_ids: (`torch.FloatTensor`):
384
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
385
+ embeddings and added to the time embeddings.
386
+ pose_latents: (`torch.FloatTensor`):
387
+ The additional latents for pose sequences.
388
+ image_only_indicator (`bool`, *optional*, defaults to `False`):
389
+ Whether or not training with all images.
390
+ return_dict (`bool`, *optional*, defaults to `True`):
391
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`]
392
+ instead of a plain tuple.
393
+ Returns:
394
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
395
+ If `return_dict` is True,
396
+ an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned,
397
+ otherwise a `tuple` is returned where the first element is the sample tensor.
398
+ """
399
+ # 1. time
400
+ timesteps = timestep
401
+ if not torch.is_tensor(timesteps):
402
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
403
+ # This would be a good case for the `match` statement (Python 3.10+)
404
+ is_mps = sample.device.type == "mps"
405
+ if isinstance(timestep, float):
406
+ dtype = torch.float32 if is_mps else torch.float64
407
+ else:
408
+ dtype = torch.int32 if is_mps else torch.int64
409
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
410
+ elif len(timesteps.shape) == 0:
411
+ timesteps = timesteps[None].to(sample.device)
412
+
413
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
414
+ batch_size, num_frames = sample.shape[:2]
415
+ timesteps = timesteps.expand(batch_size)
416
+
417
+ t_emb = self.time_proj(timesteps)
418
+
419
+ # `Timesteps` does not contain any weights and will always return f32 tensors
420
+ # but time_embedding might actually be running in fp16. so we need to cast here.
421
+ # there might be better ways to encapsulate this.
422
+ t_emb = t_emb.to(dtype=sample.dtype)
423
+
424
+ emb = self.time_embedding(t_emb)
425
+
426
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
427
+ time_embeds = time_embeds.reshape((batch_size, -1))
428
+ time_embeds = time_embeds.to(emb.dtype)
429
+ aug_emb = self.add_embedding(time_embeds)
430
+ emb = emb + aug_emb
431
+
432
+ # Flatten the batch and frames dimensions
433
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
434
+ sample = sample.flatten(0, 1)
435
+ # Repeat the embeddings num_video_frames times
436
+ # emb: [batch, channels] -> [batch * frames, channels]
437
+ emb = emb.repeat_interleave(num_frames, dim=0)
438
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
439
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
440
+
441
+ # 2. pre-process
442
+ sample = self.conv_in(sample)
443
+ if pose_latents is not None:
444
+ sample = sample + pose_latents
445
+
446
+ image_only_indicator = torch.ones(batch_size, num_frames, dtype=sample.dtype, device=sample.device) \
447
+ if image_only_indicator else torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
448
+
449
+ down_block_res_samples = (sample,)
450
+ for downsample_block in self.down_blocks:
451
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
452
+ sample, res_samples = downsample_block(
453
+ hidden_states=sample,
454
+ temb=emb,
455
+ encoder_hidden_states=encoder_hidden_states,
456
+ image_only_indicator=image_only_indicator,
457
+ )
458
+ else:
459
+ sample, res_samples = downsample_block(
460
+ hidden_states=sample,
461
+ temb=emb,
462
+ image_only_indicator=image_only_indicator,
463
+ )
464
+
465
+ down_block_res_samples += res_samples
466
+
467
+ # 4. mid
468
+ sample = self.mid_block(
469
+ hidden_states=sample,
470
+ temb=emb,
471
+ encoder_hidden_states=encoder_hidden_states,
472
+ image_only_indicator=image_only_indicator,
473
+ )
474
+
475
+ # 5. up
476
+ for i, upsample_block in enumerate(self.up_blocks):
477
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
478
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
479
+
480
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
481
+ sample = upsample_block(
482
+ hidden_states=sample,
483
+ temb=emb,
484
+ res_hidden_states_tuple=res_samples,
485
+ encoder_hidden_states=encoder_hidden_states,
486
+ image_only_indicator=image_only_indicator,
487
+ )
488
+ else:
489
+ sample = upsample_block(
490
+ hidden_states=sample,
491
+ temb=emb,
492
+ res_hidden_states_tuple=res_samples,
493
+ image_only_indicator=image_only_indicator,
494
+ )
495
+
496
+ # 6. post-process
497
+ sample = self.conv_norm_out(sample)
498
+ sample = self.conv_act(sample)
499
+ sample = self.conv_out(sample)
500
+
501
+ # 7. Reshape back to original shape
502
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
503
+
504
+ if not return_dict:
505
+ return (sample,)
506
+
507
+ return UNetSpatioTemporalConditionOutput(sample=sample)
mimicmotion/pipelines/pipeline_mimicmotion.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, Dict, List, Optional, Union
4
+
5
+ import PIL.Image
6
+ import einops
7
+ import numpy as np
8
+ import torch
9
+ from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
10
+ from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
11
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
12
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
13
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion \
14
+ import _resize_with_antialiasing, _append_dims
15
+ from diffusers.schedulers import EulerDiscreteScheduler
16
+ from diffusers.utils import BaseOutput, logging
17
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
18
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
19
+
20
+ from ..modules.pose_net import PoseNet
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ def _append_dims(x, target_dims):
26
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
27
+ dims_to_append = target_dims - x.ndim
28
+ if dims_to_append < 0:
29
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
30
+ return x[(...,) + (None,) * dims_to_append]
31
+
32
+
33
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
34
+ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
35
+ batch_size, channels, num_frames, height, width = video.shape
36
+ outputs = []
37
+ for batch_idx in range(batch_size):
38
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
39
+ batch_output = processor.postprocess(batch_vid, output_type)
40
+
41
+ outputs.append(batch_output)
42
+
43
+ if output_type == "np":
44
+ outputs = np.stack(outputs)
45
+
46
+ elif output_type == "pt":
47
+ outputs = torch.stack(outputs)
48
+
49
+ elif not output_type == "pil":
50
+ raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
51
+
52
+ return outputs
53
+
54
+
55
+ @dataclass
56
+ class MimicMotionPipelineOutput(BaseOutput):
57
+ r"""
58
+ Output class for mimicmotion pipeline.
59
+
60
+ Args:
61
+ frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):
62
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
63
+ num_frames, height, width, num_channels)`.
64
+ """
65
+
66
+ frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
67
+
68
+
69
+ class MimicMotionPipeline(DiffusionPipeline):
70
+ r"""
71
+ Pipeline to generate video from an input image using Stable Video Diffusion.
72
+
73
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
74
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
75
+
76
+ Args:
77
+ vae ([`AutoencoderKLTemporalDecoder`]):
78
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
79
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
80
+ Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K]
81
+ (https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
82
+ unet ([`UNetSpatioTemporalConditionModel`]):
83
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
84
+ scheduler ([`EulerDiscreteScheduler`]):
85
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
86
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
87
+ A `CLIPImageProcessor` to extract features from generated images.
88
+ pose_net ([`PoseNet`]):
89
+ A `` to inject pose signals into unet.
90
+ """
91
+
92
+ model_cpu_offload_seq = "image_encoder->unet->vae"
93
+ _callback_tensor_inputs = ["latents"]
94
+
95
+ def __init__(
96
+ self,
97
+ vae: AutoencoderKLTemporalDecoder,
98
+ image_encoder: CLIPVisionModelWithProjection,
99
+ unet: UNetSpatioTemporalConditionModel,
100
+ scheduler: EulerDiscreteScheduler,
101
+ feature_extractor: CLIPImageProcessor,
102
+ pose_net: PoseNet,
103
+ ):
104
+ super().__init__()
105
+
106
+ self.register_modules(
107
+ vae=vae,
108
+ image_encoder=image_encoder,
109
+ unet=unet,
110
+ scheduler=scheduler,
111
+ feature_extractor=feature_extractor,
112
+ pose_net=pose_net,
113
+ )
114
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
115
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
116
+
117
+ def _encode_image(
118
+ self,
119
+ image: PipelineImageInput,
120
+ device: Union[str, torch.device],
121
+ num_videos_per_prompt: int,
122
+ do_classifier_free_guidance: bool):
123
+ dtype = next(self.image_encoder.parameters()).dtype
124
+
125
+ if not isinstance(image, torch.Tensor):
126
+ image = self.image_processor.pil_to_numpy(image)
127
+ image = self.image_processor.numpy_to_pt(image)
128
+
129
+ # We normalize the image before resizing to match with the original implementation.
130
+ # Then we unnormalize it after resizing.
131
+ image = image * 2.0 - 1.0
132
+ image = _resize_with_antialiasing(image, (224, 224))
133
+ image = (image + 1.0) / 2.0
134
+
135
+ # Normalize the image with for CLIP input
136
+ image = self.feature_extractor(
137
+ images=image,
138
+ do_normalize=True,
139
+ do_center_crop=False,
140
+ do_resize=False,
141
+ do_rescale=False,
142
+ return_tensors="pt",
143
+ ).pixel_values
144
+
145
+ image = image.to(device=device, dtype=dtype)
146
+ image_embeddings = self.image_encoder(image).image_embeds
147
+ image_embeddings = image_embeddings.unsqueeze(1)
148
+
149
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
150
+ bs_embed, seq_len, _ = image_embeddings.shape
151
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
152
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
153
+
154
+ if do_classifier_free_guidance:
155
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
156
+
157
+ # For classifier free guidance, we need to do two forward passes.
158
+ # Here we concatenate the unconditional and text embeddings into a single batch
159
+ # to avoid doing two forward passes
160
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
161
+
162
+ return image_embeddings
163
+
164
+ def _encode_vae_image(
165
+ self,
166
+ image: torch.Tensor,
167
+ device: Union[str, torch.device],
168
+ num_videos_per_prompt: int,
169
+ do_classifier_free_guidance: bool,
170
+ ):
171
+ image = image.to(device=device, dtype=self.vae.dtype)
172
+ image_latents = self.vae.encode(image).latent_dist.mode()
173
+
174
+ if do_classifier_free_guidance:
175
+ negative_image_latents = torch.zeros_like(image_latents)
176
+
177
+ # For classifier free guidance, we need to do two forward passes.
178
+ # Here we concatenate the unconditional and text embeddings into a single batch
179
+ # to avoid doing two forward passes
180
+ image_latents = torch.cat([negative_image_latents, image_latents])
181
+
182
+ # duplicate image_latents for each generation per prompt, using mps friendly method
183
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
184
+
185
+ return image_latents
186
+
187
+ def _get_add_time_ids(
188
+ self,
189
+ fps: int,
190
+ motion_bucket_id: int,
191
+ noise_aug_strength: float,
192
+ dtype: torch.dtype,
193
+ batch_size: int,
194
+ num_videos_per_prompt: int,
195
+ do_classifier_free_guidance: bool,
196
+ ):
197
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
198
+
199
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
200
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
201
+
202
+ if expected_add_embed_dim != passed_add_embed_dim:
203
+ raise ValueError(
204
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
205
+ f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. " \
206
+ f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
207
+ )
208
+
209
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
210
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
211
+
212
+ if do_classifier_free_guidance:
213
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
214
+
215
+ return add_time_ids
216
+
217
+ def decode_latents(
218
+ self,
219
+ latents: torch.Tensor,
220
+ num_frames: int,
221
+ decode_chunk_size: int = 8):
222
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
223
+ latents = latents.flatten(0, 1)
224
+
225
+ latents = 1 / self.vae.config.scaling_factor * latents
226
+
227
+ forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
228
+ accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
229
+
230
+ # decode decode_chunk_size frames at a time to avoid OOM
231
+ frames = []
232
+ for i in range(0, latents.shape[0], decode_chunk_size):
233
+ num_frames_in = latents[i: i + decode_chunk_size].shape[0]
234
+ decode_kwargs = {}
235
+ if accepts_num_frames:
236
+ # we only pass num_frames_in if it's expected
237
+ decode_kwargs["num_frames"] = num_frames_in
238
+
239
+ frame = self.vae.decode(latents[i: i + decode_chunk_size], **decode_kwargs).sample
240
+ frames.append(frame.cpu())
241
+ frames = torch.cat(frames, dim=0)
242
+
243
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
244
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
245
+
246
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
247
+ frames = frames.float()
248
+ return frames
249
+
250
+ def check_inputs(self, image, height, width):
251
+ if (
252
+ not isinstance(image, torch.Tensor)
253
+ and not isinstance(image, PIL.Image.Image)
254
+ and not isinstance(image, list)
255
+ ):
256
+ raise ValueError(
257
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
258
+ f" {type(image)}"
259
+ )
260
+
261
+ if height % 8 != 0 or width % 8 != 0:
262
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
263
+
264
+ def prepare_latents(
265
+ self,
266
+ batch_size: int,
267
+ num_frames: int,
268
+ num_channels_latents: int,
269
+ height: int,
270
+ width: int,
271
+ dtype: torch.dtype,
272
+ device: Union[str, torch.device],
273
+ generator: torch.Generator,
274
+ latents: Optional[torch.Tensor] = None,
275
+ ):
276
+ shape = (
277
+ batch_size,
278
+ num_frames,
279
+ num_channels_latents // 2,
280
+ height // self.vae_scale_factor,
281
+ width // self.vae_scale_factor,
282
+ )
283
+ if isinstance(generator, list) and len(generator) != batch_size:
284
+ raise ValueError(
285
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
286
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
287
+ )
288
+
289
+ if latents is None:
290
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
291
+ else:
292
+ latents = latents.to(device)
293
+
294
+ # scale the initial noise by the standard deviation required by the scheduler
295
+ latents = latents * self.scheduler.init_noise_sigma
296
+ return latents
297
+
298
+ @property
299
+ def guidance_scale(self):
300
+ return self._guidance_scale
301
+
302
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
303
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
304
+ # corresponds to doing no classifier free guidance.
305
+ @property
306
+ def do_classifier_free_guidance(self):
307
+ if isinstance(self.guidance_scale, (int, float)):
308
+ return self.guidance_scale > 1
309
+ return self.guidance_scale.max() > 1
310
+
311
+ @property
312
+ def num_timesteps(self):
313
+ return self._num_timesteps
314
+
315
+ def prepare_extra_step_kwargs(self, generator, eta):
316
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
317
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
318
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
319
+ # and should be between [0, 1]
320
+
321
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
322
+ extra_step_kwargs = {}
323
+ if accepts_eta:
324
+ extra_step_kwargs["eta"] = eta
325
+
326
+ # check if the scheduler accepts generator
327
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
328
+ if accepts_generator:
329
+ extra_step_kwargs["generator"] = generator
330
+ return extra_step_kwargs
331
+
332
+ @torch.no_grad()
333
+ def __call__(
334
+ self,
335
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
336
+ image_pose: Union[torch.FloatTensor],
337
+ height: int = 576,
338
+ width: int = 1024,
339
+ num_frames: Optional[int] = None,
340
+ tile_size: Optional[int] = 16,
341
+ tile_overlap: Optional[int] = 4,
342
+ num_inference_steps: int = 25,
343
+ min_guidance_scale: float = 1.0,
344
+ max_guidance_scale: float = 3.0,
345
+ fps: int = 7,
346
+ motion_bucket_id: int = 127,
347
+ noise_aug_strength: float = 0.02,
348
+ image_only_indicator: bool = False,
349
+ decode_chunk_size: Optional[int] = None,
350
+ num_videos_per_prompt: Optional[int] = 1,
351
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
352
+ latents: Optional[torch.FloatTensor] = None,
353
+ output_type: Optional[str] = "pil",
354
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
355
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
356
+ return_dict: bool = True,
357
+ device: Union[str, torch.device] =None,
358
+ ):
359
+ r"""
360
+ The call function to the pipeline for generation.
361
+
362
+ Args:
363
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
364
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
365
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/
366
+ feature_extractor/preprocessor_config.json).
367
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
368
+ The height in pixels of the generated image.
369
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
370
+ The width in pixels of the generated image.
371
+ num_frames (`int`, *optional*):
372
+ The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid`
373
+ and to 25 for `stable-video-diffusion-img2vid-xt`
374
+ num_inference_steps (`int`, *optional*, defaults to 25):
375
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
376
+ expense of slower inference. This parameter is modulated by `strength`.
377
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
378
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
379
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
380
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
381
+ fps (`int`, *optional*, defaults to 7):
382
+ Frames per second.The rate at which the generated images shall be exported to a video after generation.
383
+ Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
384
+ motion_bucket_id (`int`, *optional*, defaults to 127):
385
+ The motion bucket ID. Used as conditioning for the generation.
386
+ The higher the number the more motion will be in the video.
387
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
388
+ The amount of noise added to the init image,
389
+ the higher it is the less the video will look like the init image. Increase it for more motion.
390
+ image_only_indicator (`bool`, *optional*, defaults to False):
391
+ Whether to treat the inputs as batch of images instead of videos.
392
+ decode_chunk_size (`int`, *optional*):
393
+ The number of frames to decode at a time.The higher the chunk size, the higher the temporal consistency
394
+ between frames, but also the higher the memory consumption.
395
+ By default, the decoder will decode all frames at once for maximal quality.
396
+ Reduce `decode_chunk_size` to reduce memory usage.
397
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
398
+ The number of images to generate per prompt.
399
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
400
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
401
+ generation deterministic.
402
+ latents (`torch.FloatTensor`, *optional*):
403
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
404
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
405
+ tensor is generated by sampling using the supplied random `generator`.
406
+ output_type (`str`, *optional*, defaults to `"pil"`):
407
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
408
+ callback_on_step_end (`Callable`, *optional*):
409
+ A function that calls at the end of each denoising steps during the inference. The function is called
410
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
411
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
412
+ `callback_on_step_end_tensor_inputs`.
413
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
414
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
415
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
416
+ `._callback_tensor_inputs` attribute of your pipeline class.
417
+ return_dict (`bool`, *optional*, defaults to `True`):
418
+ Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
419
+ plain tuple.
420
+ device:
421
+ On which device the pipeline runs on.
422
+
423
+ Returns:
424
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
425
+ If `return_dict` is `True`,
426
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
427
+ otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
428
+
429
+ Examples:
430
+
431
+ ```py
432
+ from diffusers import StableVideoDiffusionPipeline
433
+ from diffusers.utils import load_image, export_to_video
434
+
435
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
436
+ "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
437
+ pipe.to("cuda")
438
+
439
+ image = load_image(
440
+ "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
441
+ image = image.resize((1024, 576))
442
+
443
+ frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
444
+ export_to_video(frames, "generated.mp4", fps=7)
445
+ ```
446
+ """
447
+ # 0. Default height and width to unet
448
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
449
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
450
+
451
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
452
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
453
+
454
+ # 1. Check inputs. Raise error if not correct
455
+ self.check_inputs(image, height, width)
456
+
457
+ # 2. Define call parameters
458
+ if isinstance(image, PIL.Image.Image):
459
+ batch_size = 1
460
+ elif isinstance(image, list):
461
+ batch_size = len(image)
462
+ else:
463
+ batch_size = image.shape[0]
464
+ device = device if device is not None else self._execution_device
465
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
466
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
467
+ # corresponds to doing no classifier free guidance.
468
+ self._guidance_scale = max_guidance_scale
469
+
470
+ # 3. Encode input image
471
+ self.image_encoder.to(device)
472
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
473
+ self.image_encoder.cpu()
474
+
475
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
476
+ # is why it is reduced here.
477
+ fps = fps - 1
478
+
479
+ # 4. Encode input image using VAE
480
+ image = self.image_processor.preprocess(image, height=height, width=width).to(device)
481
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
482
+ image = image + noise_aug_strength * noise
483
+
484
+ self.vae.to(device)
485
+ image_latents = self._encode_vae_image(
486
+ image,
487
+ device=device,
488
+ num_videos_per_prompt=num_videos_per_prompt,
489
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
490
+ )
491
+ image_latents = image_latents.to(image_embeddings.dtype)
492
+ self.vae.cpu()
493
+
494
+ # Repeat the image latents for each frame so we can concatenate them with the noise
495
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
496
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
497
+
498
+ # 5. Get Added Time IDs
499
+ added_time_ids = self._get_add_time_ids(
500
+ fps,
501
+ motion_bucket_id,
502
+ noise_aug_strength,
503
+ image_embeddings.dtype,
504
+ batch_size,
505
+ num_videos_per_prompt,
506
+ self.do_classifier_free_guidance,
507
+ )
508
+ added_time_ids = added_time_ids.to(device)
509
+
510
+ # 4. Prepare timesteps
511
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None)
512
+
513
+ # 5. Prepare latent variables
514
+ num_channels_latents = self.unet.config.in_channels
515
+ latents = self.prepare_latents(
516
+ batch_size * num_videos_per_prompt,
517
+ tile_size,
518
+ num_channels_latents,
519
+ height,
520
+ width,
521
+ image_embeddings.dtype,
522
+ device,
523
+ generator,
524
+ latents,
525
+ )
526
+ latents = latents.repeat(1, num_frames // tile_size + 1, 1, 1, 1)[:, :num_frames]
527
+
528
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
529
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
530
+
531
+ # 7. Prepare guidance scale
532
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
533
+ guidance_scale = guidance_scale.to(device, latents.dtype)
534
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
535
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
536
+
537
+ self._guidance_scale = guidance_scale
538
+
539
+ # 8. Denoising loop
540
+ self._num_timesteps = len(timesteps)
541
+ indices = [[0, *range(i + 1, min(i + tile_size, num_frames))] for i in
542
+ range(0, num_frames - tile_size + 1, tile_size - tile_overlap)]
543
+ if indices[-1][-1] < num_frames - 1:
544
+ indices.append([0, *range(num_frames - tile_size + 1, num_frames)])
545
+
546
+ self.pose_net.to(device)
547
+ self.unet.to(device)
548
+
549
+ with torch.cuda.device(device):
550
+ torch.cuda.empty_cache()
551
+
552
+ with self.progress_bar(total=len(timesteps) * len(indices)) as progress_bar:
553
+ for i, t in enumerate(timesteps):
554
+ # expand the latents if we are doing classifier free guidance
555
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
556
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
557
+
558
+ # Concatenate image_latents over channels dimension
559
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
560
+
561
+ # predict the noise residual
562
+ noise_pred = torch.zeros_like(image_latents)
563
+ noise_pred_cnt = image_latents.new_zeros((num_frames,))
564
+ weight = (torch.arange(tile_size, device=device) + 0.5) * 2. / tile_size
565
+ weight = torch.minimum(weight, 2 - weight)
566
+ for idx in indices:
567
+
568
+ # classification-free inference
569
+ pose_latents = self.pose_net(image_pose[idx].to(device))
570
+ _noise_pred = self.unet(
571
+ latent_model_input[:1, idx],
572
+ t,
573
+ encoder_hidden_states=image_embeddings[:1],
574
+ added_time_ids=added_time_ids[:1],
575
+ pose_latents=None,
576
+ image_only_indicator=image_only_indicator,
577
+ return_dict=False,
578
+ )[0]
579
+ noise_pred[:1, idx] += _noise_pred * weight[:, None, None, None]
580
+
581
+ # normal inference
582
+ _noise_pred = self.unet(
583
+ latent_model_input[1:, idx],
584
+ t,
585
+ encoder_hidden_states=image_embeddings[1:],
586
+ added_time_ids=added_time_ids[1:],
587
+ pose_latents=pose_latents,
588
+ image_only_indicator=image_only_indicator,
589
+ return_dict=False,
590
+ )[0]
591
+ noise_pred[1:, idx] += _noise_pred * weight[:, None, None, None]
592
+
593
+ noise_pred_cnt[idx] += weight
594
+ progress_bar.update()
595
+ noise_pred.div_(noise_pred_cnt[:, None, None, None])
596
+
597
+ # perform guidance
598
+ if self.do_classifier_free_guidance:
599
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
600
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
601
+
602
+ # compute the previous noisy sample x_t -> x_t-1
603
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
604
+
605
+ if callback_on_step_end is not None:
606
+ callback_kwargs = {}
607
+ for k in callback_on_step_end_tensor_inputs:
608
+ callback_kwargs[k] = locals()[k]
609
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
610
+
611
+ latents = callback_outputs.pop("latents", latents)
612
+
613
+ self.pose_net.cpu()
614
+ self.unet.cpu()
615
+
616
+ if not output_type == "latent":
617
+ self.vae.decoder.to(device)
618
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
619
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
620
+ else:
621
+ frames = latents
622
+
623
+ self.maybe_free_model_hooks()
624
+
625
+ if not return_dict:
626
+ return frames
627
+
628
+ return MimicMotionPipelineOutput(frames=frames)
mimicmotion/utils/__init__.py ADDED
File without changes
mimicmotion/utils/geglu_patch.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import diffusers.models.activations
2
+
3
+
4
+ def patch_geglu_inplace():
5
+ """Patch GEGLU with inplace multiplication to save GPU memory."""
6
+ def forward(self, hidden_states):
7
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
8
+ return hidden_states.mul_(self.gelu(gate))
9
+ diffusers.models.activations.GEGLU.forward = forward
mimicmotion/utils/loader.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from diffusers.models import AutoencoderKLTemporalDecoder
6
+ from diffusers.schedulers import EulerDiscreteScheduler
7
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
8
+
9
+ from ..modules.unet import UNetSpatioTemporalConditionModel
10
+ from ..modules.pose_net import PoseNet
11
+ from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ class MimicMotionModel(torch.nn.Module):
16
+ def __init__(self, base_model_path):
17
+ """construnct base model components and load pretrained svd model except pose-net
18
+ Args:
19
+ base_model_path (str): pretrained svd model path
20
+ """
21
+ super().__init__()
22
+ self.unet = UNetSpatioTemporalConditionModel.from_config(
23
+ UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet"))
24
+ self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
25
+ base_model_path, subfolder="vae", torch_dtype=torch.float16, variant="fp16")
26
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
27
+ base_model_path, subfolder="image_encoder", torch_dtype=torch.float16, variant="fp16")
28
+ self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
29
+ base_model_path, subfolder="scheduler")
30
+ self.feature_extractor = CLIPImageProcessor.from_pretrained(
31
+ base_model_path, subfolder="feature_extractor")
32
+ # pose_net
33
+ self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0])
34
+
35
+ def create_pipeline(infer_config, device):
36
+ """create mimicmotion pipeline and load pretrained weight
37
+
38
+ Args:
39
+ infer_config (str):
40
+ device (str or torch.device): "cpu" or "cuda:{device_id}"
41
+ """
42
+ mimicmotion_models = MimicMotionModel(infer_config.base_model_path)
43
+ mimicmotion_models.load_state_dict(torch.load(infer_config.ckpt_path, map_location="cpu"), strict=False)
44
+ pipeline = MimicMotionPipeline(
45
+ vae=mimicmotion_models.vae,
46
+ image_encoder=mimicmotion_models.image_encoder,
47
+ unet=mimicmotion_models.unet,
48
+ scheduler=mimicmotion_models.noise_scheduler,
49
+ feature_extractor=mimicmotion_models.feature_extractor,
50
+ pose_net=mimicmotion_models.pose_net
51
+ )
52
+ return pipeline
53
+
mimicmotion/utils/utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ from torchvision.io import write_video
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ def save_to_mp4(frames, save_path, fps=7):
9
+ frames = frames.permute((0, 2, 3, 1)) # (f, c, h, w) to (f, h, w, c)
10
+ Path(save_path).parent.mkdir(parents=True, exist_ok=True)
11
+ write_video(save_path, frames, fps=fps)
12
+
predict.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # predict.py
2
+ import subprocess
3
+ import time
4
+ from cog import BasePredictor, Input, Path
5
+ import os
6
+ import torch
7
+ import numpy as np
8
+ from PIL import Image
9
+ from omegaconf import OmegaConf
10
+ from datetime import datetime
11
+
12
+ from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
13
+ from constants import ASPECT_RATIO
14
+
15
+ MODEL_CACHE = "models"
16
+ os.environ["HF_DATASETS_OFFLINE"] = "1"
17
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
18
+ os.environ["HF_HOME"] = MODEL_CACHE
19
+ os.environ["TORCH_HOME"] = MODEL_CACHE
20
+ os.environ["HF_DATASETS_CACHE"] = MODEL_CACHE
21
+ os.environ["TRANSFORMERS_CACHE"] = MODEL_CACHE
22
+ os.environ["HUGGINGFACE_HUB_CACHE"] = MODEL_CACHE
23
+
24
+ BASE_URL = f"https://weights.replicate.delivery/default/MimicMotion/{MODEL_CACHE}/"
25
+
26
+
27
+ def download_weights(url: str, dest: str) -> None:
28
+ # NOTE WHEN YOU EXTRACT SPECIFY THE PARENT FOLDER
29
+ start = time.time()
30
+ print("[!] Initiating download from URL: ", url)
31
+ print("[~] Destination path: ", dest)
32
+ if ".tar" in dest:
33
+ dest = os.path.dirname(dest)
34
+ command = ["pget", "-vf" + ("x" if ".tar" in url else ""), url, dest]
35
+ try:
36
+ print(f"[~] Running command: {' '.join(command)}")
37
+ subprocess.check_call(command, close_fds=False)
38
+ except subprocess.CalledProcessError as e:
39
+ print(
40
+ f"[ERROR] Failed to download weights. Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}."
41
+ )
42
+ raise
43
+ print("[+] Download completed in: ", time.time() - start, "seconds")
44
+
45
+
46
+ class Predictor(BasePredictor):
47
+ def setup(self):
48
+ """Load the model into memory to make running multiple predictions efficient"""
49
+
50
+ if not os.path.exists(MODEL_CACHE):
51
+ os.makedirs(MODEL_CACHE)
52
+ model_files = [
53
+ "DWPose.tar",
54
+ "MimicMotion.pth",
55
+ "MimicMotion_1-1.pth",
56
+ "SVD.tar",
57
+ ]
58
+ for model_file in model_files:
59
+ url = BASE_URL + model_file
60
+ filename = url.split("/")[-1]
61
+ dest_path = os.path.join(MODEL_CACHE, filename)
62
+ if not os.path.exists(dest_path.replace(".tar", "")):
63
+ download_weights(url, dest_path)
64
+
65
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ print(f"Using device: {self.device}")
67
+
68
+ # Move imports here and make them global
69
+ # This ensures model files are downloaded before importing mimicmotion modules
70
+ global MimicMotionPipeline, create_pipeline, save_to_mp4, get_video_pose, get_image_pose
71
+ from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
72
+ from mimicmotion.utils.loader import create_pipeline
73
+ from mimicmotion.utils.utils import save_to_mp4
74
+ from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose
75
+
76
+ # Load config with new checkpoint as default
77
+ self.config = OmegaConf.create(
78
+ {
79
+ "base_model_path": "models/SVD/stable-video-diffusion-img2vid-xt-1-1",
80
+ "ckpt_path": "models/MimicMotion_1-1.pth",
81
+ }
82
+ )
83
+
84
+ # Create the pipeline with the new checkpoint
85
+ self.pipeline = create_pipeline(self.config, self.device)
86
+ self.current_checkpoint = "v1-1"
87
+ self.current_dtype = torch.get_default_dtype()
88
+
89
+ def predict(
90
+ self,
91
+ motion_video: Path = Input(
92
+ description="Reference video file containing the motion to be mimicked"
93
+ ),
94
+ appearance_image: Path = Input(
95
+ description="Reference image file for the appearance of the generated video"
96
+ ),
97
+ resolution: int = Input(
98
+ description="Height of the output video in pixels. Width is automatically calculated.",
99
+ default=576,
100
+ ge=64,
101
+ le=1024,
102
+ ),
103
+ chunk_size: int = Input(
104
+ description="Number of frames to generate in each processing chunk",
105
+ default=16,
106
+ ge=2,
107
+ ),
108
+ frames_overlap: int = Input(
109
+ description="Number of overlapping frames between chunks for smoother transitions",
110
+ default=6,
111
+ ge=0,
112
+ ),
113
+ denoising_steps: int = Input(
114
+ description="Number of denoising steps in the diffusion process. More steps can improve quality but increase processing time.",
115
+ default=25,
116
+ ge=1,
117
+ le=100,
118
+ ),
119
+ noise_strength: float = Input(
120
+ description="Strength of noise augmentation. Higher values add more variation but may reduce coherence with the reference.",
121
+ default=0.0,
122
+ ge=0.0,
123
+ le=1.0,
124
+ ),
125
+ guidance_scale: float = Input(
126
+ description="Strength of guidance towards the reference. Higher values adhere more closely to the reference but may reduce creativity.",
127
+ default=2.0,
128
+ ge=0.1,
129
+ le=10.0,
130
+ ),
131
+ sample_stride: int = Input(
132
+ description="Interval for sampling frames from the reference video. Higher values skip more frames.",
133
+ default=2,
134
+ ge=1,
135
+ ),
136
+ output_frames_per_second: int = Input(
137
+ description="Frames per second of the output video. Affects playback speed.",
138
+ default=15,
139
+ ge=1,
140
+ le=60,
141
+ ),
142
+ seed: int = Input(
143
+ description="Random seed. Leave blank to randomize the seed",
144
+ default=None,
145
+ ),
146
+ checkpoint_version: str = Input(
147
+ description="Choose the checkpoint version to use",
148
+ choices=["v1", "v1-1"],
149
+ default="v1-1",
150
+ ),
151
+ ) -> Path:
152
+ """Run a single prediction on the model"""
153
+
154
+ ref_video = motion_video
155
+ ref_image = appearance_image
156
+ num_frames = chunk_size
157
+ num_inference_steps = denoising_steps
158
+ noise_aug_strength = noise_strength
159
+ fps = output_frames_per_second
160
+ use_fp16 = True
161
+
162
+ if seed is None:
163
+ seed = int.from_bytes(os.urandom(2), "big")
164
+ print(f"Using seed: {seed}")
165
+
166
+ need_pipeline_update = False
167
+
168
+ # Check if we need to switch checkpoints
169
+ if checkpoint_version != self.current_checkpoint:
170
+ if checkpoint_version == "v1":
171
+ self.config.ckpt_path = "models/MimicMotion.pth"
172
+ else: # v1-1
173
+ self.config.ckpt_path = "models/MimicMotion_1-1.pth"
174
+ need_pipeline_update = True
175
+ self.current_checkpoint = checkpoint_version
176
+
177
+ # Check if we need to switch dtype
178
+ target_dtype = torch.float16 if use_fp16 else torch.float32
179
+ if target_dtype != self.current_dtype:
180
+ torch.set_default_dtype(target_dtype)
181
+ need_pipeline_update = True
182
+ self.current_dtype = target_dtype
183
+
184
+ # Update pipeline if needed
185
+ if need_pipeline_update:
186
+ print(
187
+ f"Updating pipeline with checkpoint: {self.config.ckpt_path} and dtype: {torch.get_default_dtype()}"
188
+ )
189
+ self.pipeline = create_pipeline(self.config, self.device)
190
+
191
+ print(f"Using checkpoint: {self.config.ckpt_path}")
192
+ print(f"Using dtype: {torch.get_default_dtype()}")
193
+
194
+ print(
195
+ f"[!] ({type(ref_video)}) ref_video={ref_video}, "
196
+ f"[!] ({type(ref_image)}) ref_image={ref_image}, "
197
+ f"[!] ({type(resolution)}) resolution={resolution}, "
198
+ f"[!] ({type(num_frames)}) num_frames={num_frames}, "
199
+ f"[!] ({type(frames_overlap)}) frames_overlap={frames_overlap}, "
200
+ f"[!] ({type(num_inference_steps)}) num_inference_steps={num_inference_steps}, "
201
+ f"[!] ({type(noise_aug_strength)}) noise_aug_strength={noise_aug_strength}, "
202
+ f"[!] ({type(guidance_scale)}) guidance_scale={guidance_scale}, "
203
+ f"[!] ({type(sample_stride)}) sample_stride={sample_stride}, "
204
+ f"[!] ({type(fps)}) fps={fps}, "
205
+ f"[!] ({type(seed)}) seed={seed}, "
206
+ f"[!] ({type(use_fp16)}) use_fp16={use_fp16}"
207
+ )
208
+
209
+ # Input validation
210
+ if not ref_video.exists():
211
+ raise ValueError(f"Reference video file does not exist: {ref_video}")
212
+ if not ref_image.exists():
213
+ raise ValueError(f"Reference image file does not exist: {ref_image}")
214
+
215
+ if resolution % 8 != 0:
216
+ raise ValueError(f"Resolution must be a multiple of 8, got {resolution}")
217
+
218
+ if resolution < 64 or resolution > 1024:
219
+ raise ValueError(
220
+ f"Resolution must be between 64 and 1024, got {resolution}"
221
+ )
222
+
223
+ if num_frames <= frames_overlap:
224
+ raise ValueError(
225
+ f"Number of frames ({num_frames}) must be greater than frames overlap ({frames_overlap})"
226
+ )
227
+
228
+ if num_frames < 2:
229
+ raise ValueError(f"Number of frames must be at least 2, got {num_frames}")
230
+
231
+ if frames_overlap < 0:
232
+ raise ValueError(
233
+ f"Frames overlap must be non-negative, got {frames_overlap}"
234
+ )
235
+
236
+ if num_inference_steps < 1 or num_inference_steps > 100:
237
+ raise ValueError(
238
+ f"Number of inference steps must be between 1 and 100, got {num_inference_steps}"
239
+ )
240
+
241
+ if noise_aug_strength < 0.0 or noise_aug_strength > 1.0:
242
+ raise ValueError(
243
+ f"Noise augmentation strength must be between 0.0 and 1.0, got {noise_aug_strength}"
244
+ )
245
+
246
+ if guidance_scale < 0.1 or guidance_scale > 10.0:
247
+ raise ValueError(
248
+ f"Guidance scale must be between 0.1 and 10.0, got {guidance_scale}"
249
+ )
250
+
251
+ if sample_stride < 1:
252
+ raise ValueError(f"Sample stride must be at least 1, got {sample_stride}")
253
+
254
+ if fps < 1 or fps > 60:
255
+ raise ValueError(f"FPS must be between 1 and 60, got {fps}")
256
+
257
+ try:
258
+ # Preprocess
259
+ pose_pixels, image_pixels = self.preprocess(
260
+ str(ref_video),
261
+ str(ref_image),
262
+ resolution=resolution,
263
+ sample_stride=sample_stride,
264
+ )
265
+
266
+ # Run pipeline
267
+ video_frames = self.run_pipeline(
268
+ image_pixels,
269
+ pose_pixels,
270
+ num_frames=num_frames,
271
+ frames_overlap=frames_overlap,
272
+ num_inference_steps=num_inference_steps,
273
+ noise_aug_strength=noise_aug_strength,
274
+ guidance_scale=guidance_scale,
275
+ seed=seed,
276
+ )
277
+
278
+ # Save output
279
+ output_path = f"/tmp/output_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"
280
+ save_to_mp4(video_frames, output_path, fps=fps)
281
+
282
+ return Path(output_path)
283
+
284
+ except Exception as e:
285
+ print(f"An error occurred during prediction: {str(e)}")
286
+ raise
287
+
288
+ def preprocess(self, video_path, image_path, resolution=576, sample_stride=2):
289
+ image_pixels = Image.open(image_path).convert("RGB")
290
+ image_pixels = pil_to_tensor(image_pixels) # (c, h, w)
291
+ h, w = image_pixels.shape[-2:]
292
+
293
+ if h > w:
294
+ w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
295
+ else:
296
+ w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
297
+
298
+ h_w_ratio = float(h) / float(w)
299
+ if h_w_ratio < h_target / w_target:
300
+ h_resize, w_resize = h_target, int(h_target / h_w_ratio)
301
+ else:
302
+ h_resize, w_resize = int(w_target * h_w_ratio), w_target
303
+
304
+ image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
305
+ image_pixels = center_crop(image_pixels, [h_target, w_target])
306
+ image_pixels = image_pixels.permute((1, 2, 0)).numpy()
307
+
308
+ image_pose = get_image_pose(image_pixels)
309
+ video_pose = get_video_pose(
310
+ video_path, image_pixels, sample_stride=sample_stride
311
+ )
312
+
313
+ pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
314
+ image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
315
+
316
+ return (
317
+ torch.from_numpy(pose_pixels.copy()) / 127.5 - 1,
318
+ torch.from_numpy(image_pixels) / 127.5 - 1,
319
+ )
320
+
321
+ def run_pipeline(
322
+ self,
323
+ image_pixels,
324
+ pose_pixels,
325
+ num_frames,
326
+ frames_overlap,
327
+ num_inference_steps,
328
+ noise_aug_strength,
329
+ guidance_scale,
330
+ seed,
331
+ ):
332
+ image_pixels = [
333
+ Image.fromarray(
334
+ (img.cpu().numpy().transpose(1, 2, 0) * 127.5 + 127.5).astype(np.uint8)
335
+ )
336
+ for img in image_pixels
337
+ ]
338
+ pose_pixels = pose_pixels.unsqueeze(0).to(self.device)
339
+
340
+ generator = torch.Generator(device=self.device)
341
+ generator.manual_seed(seed)
342
+
343
+ frames = self.pipeline(
344
+ image_pixels,
345
+ image_pose=pose_pixels,
346
+ num_frames=pose_pixels.size(1),
347
+ tile_size=num_frames,
348
+ tile_overlap=frames_overlap,
349
+ height=pose_pixels.shape[-2],
350
+ width=pose_pixels.shape[-1],
351
+ fps=7,
352
+ noise_aug_strength=noise_aug_strength,
353
+ num_inference_steps=num_inference_steps,
354
+ generator=generator,
355
+ min_guidance_scale=guidance_scale,
356
+ max_guidance_scale=guidance_scale,
357
+ decode_chunk_size=8,
358
+ output_type="pt",
359
+ device=self.device,
360
+ ).frames.cpu()
361
+
362
+ video_frames = (frames * 255.0).to(torch.uint8)
363
+ return video_frames[0, 1:] # Remove the first frame (reference image)