Spaces:
Runtime error
Runtime error
hao he
commited on
Commit
•
308c973
1
Parent(s):
88b231a
Add gradio codes for CameraCtrl with SVD-xt model
Browse files- LICENSE.txt +201 -0
- README.md +1 -1
- app.py +579 -0
- assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png +0 -0
- assets/example_condition_images/A_car_running_on_Mars..png +0 -0
- assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png +0 -0
- assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png +0 -0
- assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png +0 -0
- assets/example_condition_images/An_exploding_cheese_house..png +0 -0
- assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png +0 -0
- assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png +0 -0
- assets/example_condition_images/Leaves_are_falling_from_trees..png +0 -0
- assets/example_condition_images/Rocky_coastline_with_crashing_waves..png +0 -0
- assets/pose_files/0bf152ef84195293.txt +26 -0
- assets/pose_files/0c11dbe781b1c11c.txt +26 -0
- assets/pose_files/0c9b371cc6225682.txt +26 -0
- assets/pose_files/0f47577ab3441480.txt +26 -0
- assets/pose_files/0f68374b76390082.txt +26 -0
- assets/pose_files/2c80f9eb0d3b2bb4.txt +26 -0
- assets/pose_files/2f25826f0d0ef09a.txt +26 -0
- assets/pose_files/3f79dc32d575bcdc.txt +26 -0
- assets/pose_files/4a2d6753676df096.txt +26 -0
- assets/reference_videos/0bf152ef84195293.mp4 +0 -0
- assets/reference_videos/0c11dbe781b1c11c.mp4 +0 -0
- assets/reference_videos/0c9b371cc6225682.mp4 +0 -0
- assets/reference_videos/0f47577ab3441480.mp4 +0 -0
- assets/reference_videos/0f68374b76390082.mp4 +0 -0
- assets/reference_videos/2c80f9eb0d3b2bb4.mp4 +0 -0
- assets/reference_videos/2f25826f0d0ef09a.mp4 +0 -0
- assets/reference_videos/3f79dc32d575bcdc.mp4 +0 -0
- assets/reference_videos/4a2d6753676df096.mp4 +0 -0
- cameractrl/data/dataset.py +355 -0
- cameractrl/models/attention.py +65 -0
- cameractrl/models/attention_processor.py +591 -0
- cameractrl/models/motion_module.py +399 -0
- cameractrl/models/pose_adaptor.py +240 -0
- cameractrl/models/transformer_temporal.py +191 -0
- cameractrl/models/unet.py +587 -0
- cameractrl/models/unet_3d_blocks.py +461 -0
- cameractrl/pipelines/pipeline_animation.py +523 -0
- cameractrl/utils/convert_from_ckpt.py +556 -0
- cameractrl/utils/convert_lora_safetensor_to_diffusers.py +154 -0
- cameractrl/utils/util.py +148 -0
- configs/train_cameractrl/svd_320_576_cameractrl.yaml +87 -0
- configs/train_cameractrl/svdxt_320_576_cameractrl.yaml +88 -0
- inference_cameractrl.py +255 -0
- requirements.txt +20 -0
LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
title: CameraCtrl Svd Xt
|
3 |
-
emoji:
|
4 |
colorFrom: gray
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
|
|
1 |
---
|
2 |
title: CameraCtrl Svd Xt
|
3 |
+
emoji: 🎥
|
4 |
colorFrom: gray
|
5 |
colorTo: indigo
|
6 |
sdk: gradio
|
app.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import tempfile
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import gradio as gr
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import matplotlib as mpl
|
13 |
+
|
14 |
+
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
17 |
+
from inference_cameractrl import get_relative_pose, ray_condition, get_pipeline
|
18 |
+
from cameractrl.utils.util import save_videos_grid
|
19 |
+
|
20 |
+
cv2.setNumThreads(1)
|
21 |
+
mpl.use('agg')
|
22 |
+
|
23 |
+
#### Description ####
|
24 |
+
title = r"""<h1 align="center">CameraCtrl: Enabling Camera Control for Text-to-Video Generation</h1>"""
|
25 |
+
subtitle = r"""<h2 align="center">CameraCtrl Image2Video with <a href='https://arxiv.org/abs/2311.15127' target='_blank'> <b>Stable Video Diffusion (SVD)</b> </a>-xt <a href='https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt' target='_blank'> <b> model </b> </a> </h2>"""
|
26 |
+
description = r"""
|
27 |
+
<b>Official Gradio demo</b> for <a href='https://github.com/hehao13/CameraCtrl' target='_blank'><b>CameraCtrl: Enabling Camera Control for Text-to-Video Generation</b></a>.<br>
|
28 |
+
CameraCtrl is capable of precisely controlling the camera trajectory during the video generation process.<br>
|
29 |
+
Note that, with SVD-xt, CameraCtrl only support Image2Video now.<br>
|
30 |
+
"""
|
31 |
+
|
32 |
+
closing_words = r"""
|
33 |
+
|
34 |
+
---
|
35 |
+
|
36 |
+
If you are interested in this demo or CameraCtrl is helpful for you, please give us a ⭐ of the <a href='https://github.com/hehao13/CameraCtrl' target='_blank'> CameraCtrl</a> Github Repo !
|
37 |
+
[![GitHub Stars](https://img.shields.io/github/stars/hehao13/CameraCtrl
|
38 |
+
)](https://github.com/hehao13/CameraCtrl)
|
39 |
+
|
40 |
+
---
|
41 |
+
|
42 |
+
📝 **Citation**
|
43 |
+
<br>
|
44 |
+
If you find our paper or code is useful for your research, please consider citing:
|
45 |
+
```bibtex
|
46 |
+
@article{he2024cameractrl,
|
47 |
+
title={CameraCtrl: Enabling Camera Control for Text-to-Video Generation},
|
48 |
+
author={Hao He and Yinghao Xu and Yuwei Guo and Gordon Wetzstein and Bo Dai and Hongsheng Li and Ceyuan Yang},
|
49 |
+
journal={arXiv preprint arXiv:2404.02101},
|
50 |
+
year={2024}
|
51 |
+
}
|
52 |
+
```
|
53 |
+
|
54 |
+
📧 **Contact**
|
55 |
+
<br>
|
56 |
+
If you have any questions, please feel free to contact me at <b>1155203564@link.cuhk.edu.hk</b>.
|
57 |
+
|
58 |
+
**Acknowledgement**
|
59 |
+
<br>
|
60 |
+
We thank <a href='https://wzhouxiff.github.io/projects/MotionCtrl/' target='_blank'><b>MotionCtrl</b></a> and <a href='https://huggingface.co/spaces/lllyasviel/IC-Light' target='_blank'><b>IC-Light</b></a> for their gradio codes.<br>
|
61 |
+
"""
|
62 |
+
|
63 |
+
|
64 |
+
RESIZE_MODES = ['Resize then Center Crop', 'Directly resize']
|
65 |
+
CAMERA_TRAJECTORY_MODES = ["Provided Camera Trajectories", "Custom Camera Trajectories"]
|
66 |
+
height = 320
|
67 |
+
width = 576
|
68 |
+
num_frames = 25
|
69 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
70 |
+
|
71 |
+
config = "configs/train_cameractrl/svdxt_320_576_cameractrl.yaml"
|
72 |
+
model_id = "stabilityai/stable-video-diffusion-img2vid-xt"
|
73 |
+
ckpt = "checkpoints/CameraCtrl_svdxt.ckpt"
|
74 |
+
if not os.path.exists(ckpt):
|
75 |
+
os.makedirs("checkpoints", exist_ok=True)
|
76 |
+
os.system("wget -c https://huggingface.co/hehao13/CameraCtrl_svd/resolve/main/CameraCtrl_svdxt.ckpt?download=true")
|
77 |
+
os.system("mv CameraCtrl_svdxt.ckpt?download=true checkpoints/CameraCtrl_svdxt.ckpt")
|
78 |
+
model_config = OmegaConf.load(config)
|
79 |
+
|
80 |
+
|
81 |
+
pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], model_config['up_block_types'],
|
82 |
+
model_config['pose_encoder_kwargs'], model_config['attention_processor_kwargs'],
|
83 |
+
ckpt, True, device)
|
84 |
+
|
85 |
+
|
86 |
+
examples = [
|
87 |
+
[
|
88 |
+
"assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png",
|
89 |
+
"assets/pose_files/0bf152ef84195293.txt",
|
90 |
+
"Trajectory 1"
|
91 |
+
],
|
92 |
+
[
|
93 |
+
"assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png",
|
94 |
+
"assets/pose_files/0c9b371cc6225682.txt",
|
95 |
+
"Trajectory 2"
|
96 |
+
],
|
97 |
+
[
|
98 |
+
"assets/example_condition_images/Rocky_coastline_with_crashing_waves..png",
|
99 |
+
"assets/pose_files/0c11dbe781b1c11c.txt",
|
100 |
+
"Trajectory 3"
|
101 |
+
],
|
102 |
+
[
|
103 |
+
"assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png",
|
104 |
+
"assets/pose_files/0f47577ab3441480.txt",
|
105 |
+
"Trajectory 4"
|
106 |
+
],
|
107 |
+
[
|
108 |
+
"assets/example_condition_images/An_exploding_cheese_house..png",
|
109 |
+
"assets/pose_files/0f47577ab3441480.txt",
|
110 |
+
"Trajectory 4"
|
111 |
+
],
|
112 |
+
[
|
113 |
+
"assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png",
|
114 |
+
"assets/pose_files/0f68374b76390082.txt",
|
115 |
+
"Trajectory 5"
|
116 |
+
],
|
117 |
+
[
|
118 |
+
"assets/example_condition_images/Leaves_are_falling_from_trees..png",
|
119 |
+
"assets/pose_files/2c80f9eb0d3b2bb4.txt",
|
120 |
+
"Trajectory 6"
|
121 |
+
],
|
122 |
+
[
|
123 |
+
"assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png",
|
124 |
+
"assets/pose_files/2f25826f0d0ef09a.txt",
|
125 |
+
"Trajectory 7"
|
126 |
+
],
|
127 |
+
[
|
128 |
+
"assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png",
|
129 |
+
"assets/pose_files/3f79dc32d575bcdc.txt",
|
130 |
+
"Trajectory 8"
|
131 |
+
],
|
132 |
+
[
|
133 |
+
"assets/example_condition_images/A_car_running_on_Mars..png",
|
134 |
+
"assets/pose_files/4a2d6753676df096.txt",
|
135 |
+
"Trajectory 9"
|
136 |
+
],
|
137 |
+
]
|
138 |
+
|
139 |
+
|
140 |
+
class Camera(object):
|
141 |
+
def __init__(self, entry):
|
142 |
+
fx, fy, cx, cy = entry[1:5]
|
143 |
+
self.fx = fx
|
144 |
+
self.fy = fy
|
145 |
+
self.cx = cx
|
146 |
+
self.cy = cy
|
147 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
148 |
+
w2c_mat_4x4 = np.eye(4)
|
149 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
150 |
+
self.w2c_mat = w2c_mat_4x4
|
151 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
152 |
+
|
153 |
+
|
154 |
+
class CameraPoseVisualizer:
|
155 |
+
def __init__(self, xlim, ylim, zlim):
|
156 |
+
self.fig = plt.figure(figsize=(18, 7))
|
157 |
+
self.ax = self.fig.add_subplot(projection='3d')
|
158 |
+
self.plotly_data = None # plotly data traces
|
159 |
+
self.ax.set_aspect("auto")
|
160 |
+
self.ax.set_xlim(xlim)
|
161 |
+
self.ax.set_ylim(ylim)
|
162 |
+
self.ax.set_zlim(zlim)
|
163 |
+
self.ax.set_xlabel('x')
|
164 |
+
self.ax.set_ylabel('y')
|
165 |
+
self.ax.set_zlabel('z')
|
166 |
+
|
167 |
+
def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9 / 16, base_xval=1, zval=3):
|
168 |
+
vertex_std = np.array([[0, 0, 0, 1],
|
169 |
+
[base_xval, -base_xval * hw_ratio, zval, 1],
|
170 |
+
[base_xval, base_xval * hw_ratio, zval, 1],
|
171 |
+
[-base_xval, base_xval * hw_ratio, zval, 1],
|
172 |
+
[-base_xval, -base_xval * hw_ratio, zval, 1]])
|
173 |
+
vertex_transformed = vertex_std @ extrinsic.T
|
174 |
+
meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
|
175 |
+
[vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
|
176 |
+
[vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
|
177 |
+
[vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
|
178 |
+
[vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
|
179 |
+
vertex_transformed[4, :-1]]]
|
180 |
+
|
181 |
+
color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
|
182 |
+
|
183 |
+
self.ax.add_collection3d(
|
184 |
+
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
|
185 |
+
|
186 |
+
def colorbar(self, max_frame_length):
|
187 |
+
cmap = mpl.cm.rainbow
|
188 |
+
norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
|
189 |
+
self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax, orientation='vertical',
|
190 |
+
label='Frame Indexes')
|
191 |
+
|
192 |
+
def show(self):
|
193 |
+
plt.title('Camera Trajectory')
|
194 |
+
plt.show()
|
195 |
+
|
196 |
+
|
197 |
+
def get_c2w(w2cs):
|
198 |
+
target_cam_c2w = np.array([
|
199 |
+
[1, 0, 0, 0],
|
200 |
+
[0, 1, 0, 0],
|
201 |
+
[0, 0, 1, 0],
|
202 |
+
[0, 0, 0, 1]
|
203 |
+
])
|
204 |
+
abs2rel = target_cam_c2w @ w2cs[0]
|
205 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
|
206 |
+
camera_positions = np.asarray([c2w[:3, 3] for c2w in ret_poses]) # [n_frame, 3]
|
207 |
+
position_distances = [camera_positions[i] - camera_positions[i - 1] for i in range(1, len(camera_positions))]
|
208 |
+
xyz_max = np.max(camera_positions, axis=0)
|
209 |
+
xyz_min = np.min(camera_positions, axis=0)
|
210 |
+
xyz_ranges = xyz_max - xyz_min # [3, ]
|
211 |
+
max_range = np.max(xyz_ranges)
|
212 |
+
expected_xyz_ranges = 1
|
213 |
+
scale_ratio = expected_xyz_ranges / max_range
|
214 |
+
scaled_position_distances = [dis * scale_ratio for dis in position_distances] # [n_frame - 1]
|
215 |
+
scaled_camera_positions = [camera_positions[0], ]
|
216 |
+
scaled_camera_positions.extend([camera_positions[0] + np.sum(np.asarray(scaled_position_distances[:i]), axis=0)
|
217 |
+
for i in range(1, len(camera_positions))])
|
218 |
+
ret_poses = [np.concatenate(
|
219 |
+
(np.concatenate((ori_pose[:3, :3], cam_position[:, None]), axis=1), np.asarray([0, 0, 0, 1])[None]), axis=0)
|
220 |
+
for ori_pose, cam_position in zip(ret_poses, scaled_camera_positions)]
|
221 |
+
transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
|
222 |
+
ret_poses = [transform_matrix @ x for x in ret_poses]
|
223 |
+
return np.array(ret_poses, dtype=np.float32)
|
224 |
+
|
225 |
+
|
226 |
+
def visualize_trajectory(trajectory_file):
|
227 |
+
with open(trajectory_file, 'r') as f:
|
228 |
+
poses = f.readlines()
|
229 |
+
w2cs = [np.asarray([float(p) for p in pose.strip().split(' ')[7:]]).reshape(3, 4) for pose in poses[1:]]
|
230 |
+
num_frames = len(w2cs)
|
231 |
+
last_row = np.zeros((1, 4))
|
232 |
+
last_row[0, -1] = 1.0
|
233 |
+
w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
|
234 |
+
c2ws = get_c2w(w2cs)
|
235 |
+
visualizer = CameraPoseVisualizer([-1.2, 1.2], [-1.2, 1.2], [-1.2, 1.2])
|
236 |
+
for frame_idx, c2w in enumerate(c2ws):
|
237 |
+
visualizer.extrinsic2pyramid(c2w, frame_idx / num_frames, hw_ratio=9 / 16, base_xval=0.02, zval=0.1)
|
238 |
+
visualizer.colorbar(num_frames)
|
239 |
+
return visualizer.fig
|
240 |
+
|
241 |
+
|
242 |
+
vis_traj = visualize_trajectory('assets/pose_files/0bf152ef84195293.txt')
|
243 |
+
|
244 |
+
|
245 |
+
@torch.inference_mode()
|
246 |
+
def process_input_image(input_image, resize_mode):
|
247 |
+
global height, width
|
248 |
+
expected_hw_ratio = height / width
|
249 |
+
inp_w, inp_h = input_image.size
|
250 |
+
inp_hw_ratio = inp_h / inp_w
|
251 |
+
|
252 |
+
if inp_hw_ratio > expected_hw_ratio:
|
253 |
+
resized_height = inp_hw_ratio * width
|
254 |
+
resized_width = width
|
255 |
+
else:
|
256 |
+
resized_height = height
|
257 |
+
resized_width = height / inp_hw_ratio
|
258 |
+
resized_image = F.resize(input_image, size=[resized_height, resized_width])
|
259 |
+
|
260 |
+
if resize_mode == RESIZE_MODES[0]:
|
261 |
+
return_image = F.center_crop(resized_image, output_size=[height, width])
|
262 |
+
else:
|
263 |
+
return_image = resized_image
|
264 |
+
|
265 |
+
return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), gr.update(
|
266 |
+
visible=True), gr.update(visible=True), gr.update(visible=True)
|
267 |
+
|
268 |
+
|
269 |
+
def update_camera_trajectories(trajectory_mode):
|
270 |
+
if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
|
271 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
|
272 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
273 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
274 |
+
elif trajectory_mode == CAMERA_TRAJECTORY_MODES[1]:
|
275 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
276 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
|
277 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
278 |
+
|
279 |
+
|
280 |
+
def update_camera_args(trajectory_mode, provided_camera_trajectory, customized_trajectory_file):
|
281 |
+
if trajectory_mode == CAMERA_TRAJECTORY_MODES[0]:
|
282 |
+
res = "Provided " + str(provided_camera_trajectory)
|
283 |
+
else:
|
284 |
+
if customized_trajectory_file is None:
|
285 |
+
res = " "
|
286 |
+
else:
|
287 |
+
res = f"Customized trajectory file {customized_trajectory_file.name.split('/')[-1]}"
|
288 |
+
return res
|
289 |
+
|
290 |
+
|
291 |
+
def update_camera_args_reset():
|
292 |
+
return " "
|
293 |
+
|
294 |
+
|
295 |
+
def update_trajectory_vis_plot(camera_trajectory_args, provided_camera_trajectory, customized_trajectory_file):
|
296 |
+
if 'Provided' in camera_trajectory_args:
|
297 |
+
if provided_camera_trajectory == "Trajectory 1":
|
298 |
+
trajectory_file_path = "assets/pose_files/0bf152ef84195293.txt"
|
299 |
+
elif provided_camera_trajectory == "Trajectory 2":
|
300 |
+
trajectory_file_path = "assets/pose_files/0c9b371cc6225682.txt"
|
301 |
+
elif provided_camera_trajectory == "Trajectory 3":
|
302 |
+
trajectory_file_path = "assets/pose_files/0c11dbe781b1c11c.txt"
|
303 |
+
elif provided_camera_trajectory == "Trajectory 4":
|
304 |
+
trajectory_file_path = "assets/pose_files/0f47577ab3441480.txt"
|
305 |
+
elif provided_camera_trajectory == "Trajectory 5":
|
306 |
+
trajectory_file_path = "assets/pose_files/0f68374b76390082.txt"
|
307 |
+
elif provided_camera_trajectory == "Trajectory 6":
|
308 |
+
trajectory_file_path = "assets/pose_files/2c80f9eb0d3b2bb4.txt"
|
309 |
+
elif provided_camera_trajectory == "Trajectory 7":
|
310 |
+
trajectory_file_path = "assets/pose_files/2f25826f0d0ef09a.txt"
|
311 |
+
elif provided_camera_trajectory == "Trajectory 8":
|
312 |
+
trajectory_file_path = "assets/pose_files/3f79dc32d575bcdc.txt"
|
313 |
+
else:
|
314 |
+
trajectory_file_path = "assets/pose_files/4a2d6753676df096.txt"
|
315 |
+
else:
|
316 |
+
trajectory_file_path = customized_trajectory_file.name
|
317 |
+
vis_traj = visualize_trajectory(trajectory_file_path)
|
318 |
+
return gr.update(visible=True), vis_traj, gr.update(visible=True), gr.update(visible=True), \
|
319 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
|
320 |
+
gr.update(visible=True), gr.update(visible=True), trajectory_file_path
|
321 |
+
|
322 |
+
|
323 |
+
def update_set_button():
|
324 |
+
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
325 |
+
|
326 |
+
|
327 |
+
def update_buttons_for_example(example_image, example_traj_path, provided_traj_name):
|
328 |
+
global height, width
|
329 |
+
return_image = example_image
|
330 |
+
return gr.update(visible=True, value=return_image, height=height, width=width), gr.update(visible=True), \
|
331 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
|
332 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), \
|
333 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), \
|
334 |
+
gr.update(visible=True)
|
335 |
+
|
336 |
+
@spaces.GPU
|
337 |
+
@torch.inference_mode()
|
338 |
+
def sample_video(condition_image, trajectory_file, num_inference_step, min_guidance_scale, max_guidance_scale, fps_id, seed):
|
339 |
+
global height, width, num_frames, device, pipeline
|
340 |
+
with open(trajectory_file, 'r') as f:
|
341 |
+
poses = f.readlines()
|
342 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
343 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
344 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
345 |
+
sample_wh_ratio = width / height
|
346 |
+
pose_wh_ratio = cam_params[0].fy / cam_params[0].fx
|
347 |
+
if pose_wh_ratio > sample_wh_ratio:
|
348 |
+
resized_ori_w = height * pose_wh_ratio
|
349 |
+
for cam_param in cam_params:
|
350 |
+
cam_param.fx = resized_ori_w * cam_param.fx / width
|
351 |
+
else:
|
352 |
+
resized_ori_h = width / pose_wh_ratio
|
353 |
+
for cam_param in cam_params:
|
354 |
+
cam_param.fy = resized_ori_h * cam_param.fy / height
|
355 |
+
intrinsic = np.asarray([[cam_param.fx * width,
|
356 |
+
cam_param.fy * height,
|
357 |
+
cam_param.cx * width,
|
358 |
+
cam_param.cy * height]
|
359 |
+
for cam_param in cam_params], dtype=np.float32)
|
360 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
361 |
+
c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
|
362 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
363 |
+
plucker_embedding = ray_condition(K, c2ws, height, width, device='cpu') # b f h w 6
|
364 |
+
plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)
|
365 |
+
|
366 |
+
generator = torch.Generator(device=device)
|
367 |
+
generator.manual_seed(int(seed))
|
368 |
+
|
369 |
+
with torch.no_grad():
|
370 |
+
sample = pipeline(
|
371 |
+
image=condition_image,
|
372 |
+
pose_embedding=plucker_embedding,
|
373 |
+
height=height,
|
374 |
+
width=width,
|
375 |
+
num_frames=num_frames,
|
376 |
+
num_inference_steps=num_inference_step,
|
377 |
+
min_guidance_scale=min_guidance_scale,
|
378 |
+
max_guidance_scale=max_guidance_scale,
|
379 |
+
fps=fps_id,
|
380 |
+
do_image_process=True,
|
381 |
+
generator=generator,
|
382 |
+
output_type='pt'
|
383 |
+
).frames[0].transpose(0, 1).cpu()
|
384 |
+
|
385 |
+
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
386 |
+
save_videos_grid(sample[None], temporal_video_path, rescale=False)
|
387 |
+
|
388 |
+
return temporal_video_path
|
389 |
+
|
390 |
+
|
391 |
+
def main(args):
|
392 |
+
demo = gr.Blocks().queue()
|
393 |
+
with demo:
|
394 |
+
gr.Markdown(title)
|
395 |
+
gr.Markdown(subtitle)
|
396 |
+
gr.Markdown(description)
|
397 |
+
|
398 |
+
with gr.Column():
|
399 |
+
# step1: Input condition image
|
400 |
+
step1_title = gr.Markdown("---\n## Step 1: Input an Image", show_label=False, visible=True)
|
401 |
+
step1_dec = gr.Markdown(f"\n 1. Upload an Image by `Drag` or Click `Upload Image`; \
|
402 |
+
\n 2. Click `{RESIZE_MODES[0]}` or `{RESIZE_MODES[1]}` to select the image resize mode. \
|
403 |
+
\n - `{RESIZE_MODES[0]}`: First resize the input image, then center crop it into the resolution of 320 x 576. \
|
404 |
+
\n - `{RESIZE_MODES[1]}`: Only resize the input image, and keep the original aspect ratio.",
|
405 |
+
show_label=False, visible=True)
|
406 |
+
with gr.Row(equal_height=True):
|
407 |
+
with gr.Column(scale=2):
|
408 |
+
input_image = gr.Image(type='pil', interactive=True, elem_id='condition_image',
|
409 |
+
elem_classes='image',
|
410 |
+
visible=True)
|
411 |
+
with gr.Row():
|
412 |
+
resize_crop_button = gr.Button(RESIZE_MODES[0], visible=True)
|
413 |
+
directly_resize_button = gr.Button(RESIZE_MODES[1], visible=True)
|
414 |
+
with gr.Column(scale=2):
|
415 |
+
processed_image = gr.Image(type='pil', interactive=False, elem_id='processed_image',
|
416 |
+
elem_classes='image', visible=False)
|
417 |
+
|
418 |
+
# step2: Select camera trajectory
|
419 |
+
step2_camera_trajectory = gr.Markdown("---\n## Step 2: Select the camera trajectory", show_label=False,
|
420 |
+
visible=False)
|
421 |
+
step2_camera_trajectory_des = gr.Markdown(f"\n - `{CAMERA_TRAJECTORY_MODES[0]}`: Including 9 camera trajectories extracted from the test set of RealEstate10K dataset, each has 25 frames. \
|
422 |
+
\n - `{CAMERA_TRAJECTORY_MODES[1]}`: You can provide the customized camera trajectories in the txt file.",
|
423 |
+
show_label=False, visible=False)
|
424 |
+
with gr.Row(equal_height=True):
|
425 |
+
provide_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[0], visible=False)
|
426 |
+
customized_trajectory_button = gr.Button(CAMERA_TRAJECTORY_MODES[1], visible=False)
|
427 |
+
with gr.Row():
|
428 |
+
with gr.Column():
|
429 |
+
provided_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[0]}", show_label=False,
|
430 |
+
visible=False)
|
431 |
+
provided_camera_trajectory_des = gr.Markdown(f"\n 1. Click one of the provide camera trajectories, such as `Trajectory 1`; \
|
432 |
+
\n 2. Click `Visualize Trajectory` to visualize the camera trajectory; \
|
433 |
+
\n 3. Click `Reset Trajectory` to reset the camera trajectory. ",
|
434 |
+
show_label=False, visible=False)
|
435 |
+
|
436 |
+
customized_camera_trajectory = gr.Markdown(f"---\n### {CAMERA_TRAJECTORY_MODES[1]}",
|
437 |
+
show_label=False,
|
438 |
+
visible=False)
|
439 |
+
customized_run_status = gr.Markdown(f"\n 1. Input the txt file containing camera trajectory. \
|
440 |
+
\n 2. Click `Visualize Trajectory` to visualize the camera trajectory; \
|
441 |
+
\n 3. Click `Reset Trajectory` to reset the camera trajectory. ",
|
442 |
+
show_label=False, visible=False)
|
443 |
+
|
444 |
+
with gr.Row():
|
445 |
+
provided_trajectories = gr.Dropdown(
|
446 |
+
["Trajectory 1", "Trajectory 2", "Trajectory 3", "Trajectory 4", "Trajectory 5",
|
447 |
+
"Trajectory 6", "Trajectory 7", "Trajectory 8", "Trajectory 9"],
|
448 |
+
label="Provided Trajectories", interactive=True, visible=False)
|
449 |
+
with gr.Row():
|
450 |
+
customized_camera_trajectory_file = gr.File(
|
451 |
+
label="Upload customized camera trajectory (in .txt format).", visible=False, interactive=True)
|
452 |
+
|
453 |
+
with gr.Row():
|
454 |
+
camera_args = gr.Textbox(value=" ", label="Camera Trajectory Name", visible=False)
|
455 |
+
camera_trajectory_path = gr.Textbox(value=" ", visible=False)
|
456 |
+
|
457 |
+
with gr.Row():
|
458 |
+
camera_trajectory_vis = gr.Button(value="Visualize Camera Trajectory", visible=False)
|
459 |
+
camera_trajectory_reset = gr.Button(value="Reset Camera Trajectory", visible=False)
|
460 |
+
with gr.Column():
|
461 |
+
vis_camera_trajectory = gr.Plot(vis_traj, label='Camera Trajectory', visible=False)
|
462 |
+
|
463 |
+
# step3: Set inference parameters
|
464 |
+
with gr.Row():
|
465 |
+
with gr.Column():
|
466 |
+
step3_title = gr.Markdown(f"---\n## Step3: Setting the inference hyper-parameters.", visible=False)
|
467 |
+
step3_des = gr.Markdown(
|
468 |
+
f"\n 1. Set the mumber of inference step; \
|
469 |
+
\n 2. Set the seed; \
|
470 |
+
\n 3. Set the minimum guidance scale and the maximum guidance scale; \
|
471 |
+
\n 4. Set the fps; \
|
472 |
+
\n - Please refer to the SVD paper for the meaning of the last three parameter",
|
473 |
+
visible=False)
|
474 |
+
with gr.Row():
|
475 |
+
with gr.Column():
|
476 |
+
num_inference_steps = gr.Number(value=25, label='Number Inference Steps', step=1, interactive=True,
|
477 |
+
visible=False)
|
478 |
+
with gr.Column():
|
479 |
+
seed = gr.Number(value=42, label='Seed', minimum=1, interactive=True, visible=False, step=1)
|
480 |
+
with gr.Column():
|
481 |
+
min_guidance_scale = gr.Number(value=1.0, label='Minimum Guidance Scale', minimum=1.0, step=0.5,
|
482 |
+
interactive=True, visible=False)
|
483 |
+
with gr.Column():
|
484 |
+
max_guidance_scale = gr.Number(value=3.0, label='Maximum Guidance Scale', minimum=1.0, step=0.5,
|
485 |
+
interactive=True, visible=False)
|
486 |
+
with gr.Column():
|
487 |
+
fps = gr.Number(value=7, label='FPS', minimum=1, step=1, interactive=True, visible=False)
|
488 |
+
with gr.Column():
|
489 |
+
_ = gr.Button("Seed", visible=False)
|
490 |
+
with gr.Column():
|
491 |
+
_ = gr.Button("Seed", visible=False)
|
492 |
+
with gr.Column():
|
493 |
+
_ = gr.Button("Seed", visible=False)
|
494 |
+
with gr.Row():
|
495 |
+
with gr.Column():
|
496 |
+
_ = gr.Button("Set", visible=False)
|
497 |
+
with gr.Column():
|
498 |
+
set_button = gr.Button("Set", visible=False)
|
499 |
+
with gr.Column():
|
500 |
+
_ = gr.Button("Set", visible=False)
|
501 |
+
|
502 |
+
# step 4: Generate video
|
503 |
+
with gr.Row():
|
504 |
+
with gr.Column():
|
505 |
+
step4_title = gr.Markdown("---\n## Step4 Generating video", show_label=False, visible=False)
|
506 |
+
step4_des = gr.Markdown(f"\n - Click the `Start generation !` button to generate the video.; \
|
507 |
+
\n - If the content of generated video is not very aligned with the condition image, try to increase the `Minimum Guidance Scale` and `Maximum Guidance Scale`. \
|
508 |
+
\n - If the generated videos are distored, try to increase `FPS`.",
|
509 |
+
visible=False)
|
510 |
+
start_button = gr.Button(value="Start generation !", visible=False)
|
511 |
+
with gr.Column():
|
512 |
+
generate_video = gr.Video(value=None, label="Generate Video", visible=False)
|
513 |
+
resize_crop_button.click(fn=process_input_image, inputs=[input_image, resize_crop_button],
|
514 |
+
outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
|
515 |
+
provide_trajectory_button, customized_trajectory_button])
|
516 |
+
directly_resize_button.click(fn=process_input_image, inputs=[input_image, directly_resize_button],
|
517 |
+
outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des,
|
518 |
+
provide_trajectory_button, customized_trajectory_button])
|
519 |
+
provide_trajectory_button.click(fn=update_camera_trajectories, inputs=[provide_trajectory_button],
|
520 |
+
outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
|
521 |
+
provided_trajectories,
|
522 |
+
customized_camera_trajectory, customized_run_status,
|
523 |
+
customized_camera_trajectory_file,
|
524 |
+
camera_args, camera_trajectory_vis, camera_trajectory_reset])
|
525 |
+
customized_trajectory_button.click(fn=update_camera_trajectories, inputs=[customized_trajectory_button],
|
526 |
+
outputs=[provided_camera_trajectory, provided_camera_trajectory_des,
|
527 |
+
provided_trajectories,
|
528 |
+
customized_camera_trajectory, customized_run_status,
|
529 |
+
customized_camera_trajectory_file,
|
530 |
+
camera_args, camera_trajectory_vis, camera_trajectory_reset])
|
531 |
+
|
532 |
+
provided_trajectories.change(fn=update_camera_args, inputs=[provide_trajectory_button, provided_trajectories, customized_camera_trajectory_file],
|
533 |
+
outputs=[camera_args])
|
534 |
+
customized_camera_trajectory_file.change(fn=update_camera_args, inputs=[customized_trajectory_button, provided_trajectories, customized_camera_trajectory_file],
|
535 |
+
outputs=[camera_args])
|
536 |
+
camera_trajectory_reset.click(fn=update_camera_args_reset, inputs=None, outputs=[camera_args])
|
537 |
+
camera_trajectory_vis.click(fn=update_trajectory_vis_plot, inputs=[camera_args, provided_trajectories, customized_camera_trajectory_file],
|
538 |
+
outputs=[vis_camera_trajectory, vis_camera_trajectory, step3_title, step3_des,
|
539 |
+
num_inference_steps, min_guidance_scale, max_guidance_scale, fps,
|
540 |
+
seed, set_button, camera_trajectory_path])
|
541 |
+
set_button.click(fn=update_set_button, inputs=None, outputs=[step4_title, step4_des, start_button, generate_video])
|
542 |
+
start_button.click(fn=sample_video, inputs=[processed_image, camera_trajectory_path, num_inference_steps,
|
543 |
+
min_guidance_scale, max_guidance_scale, fps, seed],
|
544 |
+
outputs=[generate_video])
|
545 |
+
|
546 |
+
# set example
|
547 |
+
gr.Markdown("## Examples")
|
548 |
+
gr.Markdown("\n Choosing the one of the following examples to get a quick start, by selecting an example, "
|
549 |
+
"we will set the condition image and camera trajectory automatically. "
|
550 |
+
"Then, you can click the `Visualize Camera Trajectory` button to visualize the camera trajectory.")
|
551 |
+
gr.Examples(
|
552 |
+
fn=update_buttons_for_example,
|
553 |
+
run_on_click=True,
|
554 |
+
cache_examples=False,
|
555 |
+
examples=examples,
|
556 |
+
inputs=[input_image, camera_args, provided_trajectories],
|
557 |
+
outputs=[processed_image, step2_camera_trajectory, step2_camera_trajectory_des, provide_trajectory_button,
|
558 |
+
customized_trajectory_button,
|
559 |
+
provided_camera_trajectory, provided_camera_trajectory_des, provided_trajectories,
|
560 |
+
customized_camera_trajectory, customized_run_status, customized_camera_trajectory_file,
|
561 |
+
camera_args, camera_trajectory_vis, camera_trajectory_reset]
|
562 |
+
)
|
563 |
+
with gr.Row():
|
564 |
+
gr.Markdown(closing_words)
|
565 |
+
|
566 |
+
demo.launch(**args)
|
567 |
+
|
568 |
+
|
569 |
+
if __name__ == '__main__':
|
570 |
+
parser = argparse.ArgumentParser()
|
571 |
+
parser.add_argument('--listen', default='0.0.0.0')
|
572 |
+
parser.add_argument('--broswer', action='store_true')
|
573 |
+
parser.add_argument('--share', action='store_true')
|
574 |
+
args = parser.parse_args()
|
575 |
+
|
576 |
+
launch_kwargs = {'server_name': args.listen,
|
577 |
+
'inbrowser': args.broswer,
|
578 |
+
'share': args.share}
|
579 |
+
main(launch_kwargs)
|
assets/example_condition_images/A_beautiful_fluffy_domestic_hen_sitting_on_white_eggs_in_a_brown_nest,_eggs_are_under_the_hen..png
ADDED
assets/example_condition_images/A_car_running_on_Mars..png
ADDED
assets/example_condition_images/A_lion_standing_on_a_surfboard_in_the_ocean..png
ADDED
assets/example_condition_images/A_serene_mountain_lake_at_sunrise,_with_mist_hovering_over_the_water..png
ADDED
assets/example_condition_images/A_tiny_finch_on_a_branch_with_spring_flowers_on_background..png
ADDED
assets/example_condition_images/An_exploding_cheese_house..png
ADDED
assets/example_condition_images/Dolphins_leaping_out_of_the_ocean_at_sunset..png
ADDED
assets/example_condition_images/Fireworks_display_illuminating_the_night_sky..png
ADDED
assets/example_condition_images/Leaves_are_falling_from_trees..png
ADDED
assets/example_condition_images/Rocky_coastline_with_crashing_waves..png
ADDED
assets/pose_files/0bf152ef84195293.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=QShWPZxTDoE
|
2 |
+
157323991 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.950234294 0.023969267 -0.310612619 -0.058392330 -0.025083920 0.999685287 0.000406042 0.179560758 0.310524613 0.007405547 0.950536489 -0.411621285
|
3 |
+
157490824 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.932122767 0.029219138 -0.360961705 0.019157260 -0.030671693 0.999528050 0.001705339 0.195243598 0.360841185 0.009481722 0.932579100 -0.489249695
|
4 |
+
157657658 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.912891090 0.034948215 -0.406704396 0.093606521 -0.036429971 0.999327779 0.004101569 0.203909523 0.406574339 0.011071944 0.913550615 -0.570709379
|
5 |
+
157824491 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.892021954 0.039648205 -0.450249761 0.174752186 -0.041337918 0.999126673 0.006083843 0.206605029 0.450097769 0.013185467 0.892881930 -0.657519766
|
6 |
+
157991325 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.870897233 0.043891508 -0.489501357 0.266759117 -0.046065997 0.998909414 0.007609563 0.208293300 0.489301533 0.015922222 0.871969342 -0.739918788
|
7 |
+
158158158 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.850054264 0.048434701 -0.524463415 0.371990684 -0.051002879 0.998652756 0.009560689 0.215520371 0.524219871 0.018622037 0.851379335 -0.814489669
|
8 |
+
158358358 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.823578537 0.052956820 -0.564724684 0.498689894 -0.055313200 0.998385012 0.012955925 0.224118528 0.564498782 0.020566508 0.825177670 -0.889946292
|
9 |
+
158525192 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.801249743 0.056292553 -0.595676124 0.608660065 -0.058902908 0.998149574 0.015096202 0.223320416 0.595423639 0.022991227 0.803082883 -0.943733076
|
10 |
+
158692025 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.780003667 0.059620168 -0.622928321 0.726968666 -0.062449891 0.997897983 0.017311305 0.217967188 0.622651041 0.025398925 0.782087326 -1.002211444
|
11 |
+
158858859 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.758300304 0.062706254 -0.648882568 0.862137737 -0.066019125 0.997632504 0.019256916 0.210050766 0.648553908 0.028236136 0.760644853 -1.055941415
|
12 |
+
159025692 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.733386099 0.066376433 -0.676564097 1.014642875 -0.069476441 0.997329056 0.022534581 0.204417168 0.676252782 0.030478716 0.736038864 -1.100931176
|
13 |
+
159192526 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.703763664 0.069719747 -0.707004845 1.176046236 -0.073094003 0.996997535 0.025557835 0.198280199 0.706663966 0.033691138 0.706746757 -1.127059555
|
14 |
+
159392726 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.662238955 0.074904546 -0.745539367 1.361111617 -0.076822795 0.996534884 0.031882886 0.176548885 0.745344102 0.036160327 0.665698588 -1.136046987
|
15 |
+
159559560 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.629729092 0.078562595 -0.772831917 1.488353738 -0.081706621 0.996052980 0.034676969 0.152860218 0.772505820 0.041308392 0.633662641 -1.137729720
|
16 |
+
159726393 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.602676034 0.081962064 -0.793765604 1.594849532 -0.083513811 0.995727122 0.039407197 0.137400253 0.793603837 0.042540617 0.606945813 -1.154423412
|
17 |
+
159893227 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.580205023 0.084535435 -0.810071528 1.693660542 -0.086233690 0.995384574 0.042109925 0.134338657 0.809892535 0.045423068 0.584816933 -1.189997045
|
18 |
+
160060060 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.559533417 0.086548090 -0.824276507 1.785956560 -0.089039005 0.995054126 0.044038296 0.143407250 0.824011147 0.048751865 0.564472198 -1.233530509
|
19 |
+
160226894 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.539407372 0.088928543 -0.837335885 1.876278781 -0.091476299 0.994710982 0.046713885 0.159821683 0.837061405 0.051398575 0.544689238 -1.287939732
|
20 |
+
160427094 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.515480161 0.090795092 -0.852077723 1.979818582 -0.093906343 0.994367242 0.049146701 0.181896137 0.851740420 0.054681353 0.521102846 -1.359775674
|
21 |
+
160593927 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.497423410 0.091656610 -0.862652302 2.062552118 -0.095314465 0.994156837 0.050668620 0.194005458 0.862255812 0.057019483 0.503253102 -1.415121326
|
22 |
+
160760761 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.484208912 0.092620693 -0.870036304 2.136687359 -0.096262053 0.993984103 0.052242137 0.204385655 0.869640946 0.058455370 0.490211815 -1.477987717
|
23 |
+
160927594 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.475284129 0.093297184 -0.874871790 2.200792438 -0.096743606 0.993874133 0.053430639 0.209217395 0.874497354 0.059243519 0.481398523 -1.547068315
|
24 |
+
161094428 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.468848795 0.093707815 -0.878293574 2.268227083 -0.097071946 0.993799806 0.054212786 0.208793720 0.877928138 0.059840068 0.475038230 -1.634971335
|
25 |
+
161261261 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.463450164 0.093811318 -0.881143212 2.339123750 -0.097783640 0.993721604 0.054366294 0.224513862 0.880711257 0.060965322 0.469713658 -1.732136350
|
26 |
+
161461461 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.458983690 0.093715429 -0.883488178 2.426412787 -0.098171033 0.993681431 0.054402962 0.253829726 0.883004189 0.061762877 0.465283692 -1.863571195
|
assets/pose_files/0c11dbe781b1c11c.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=a-Unpcomk5k
|
2 |
+
90023267 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.949961841 -0.054589756 0.307558835 0.363597957 0.049778115 0.998484373 0.023474237 0.122943811 -0.308374137 -0.006989930 0.951239467 -0.411649725
|
3 |
+
90190100 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.936324358 -0.058613990 0.346209586 0.384438270 0.053212658 0.998267829 0.025095066 0.136336848 -0.347080797 -0.005074390 0.937821507 -0.495378251
|
4 |
+
90356933 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.920415699 -0.061646536 0.386050045 0.392760735 0.055660341 0.998093307 0.026676189 0.148407963 -0.386958480 -0.003065505 0.922092021 -0.584840288
|
5 |
+
90523767 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.902696550 -0.065121211 0.425321281 0.393740251 0.058245987 0.997876167 0.029164905 0.157571476 -0.426317245 -0.001553800 0.904572368 -0.683591501
|
6 |
+
90690600 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.883015513 -0.069272175 0.464203626 0.383146000 0.061874375 0.997597098 0.031171000 0.171538756 -0.465247482 0.001197834 0.885179818 -0.798848920
|
7 |
+
90857433 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.863586664 -0.074230894 0.498706162 0.378236544 0.067191981 0.997224212 0.032080498 0.194804574 -0.499703199 0.005804764 0.866177261 -0.912604869
|
8 |
+
91057633 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.837207139 -0.080319084 0.540955663 0.348125228 0.072502285 0.996726155 0.035782419 0.216269091 -0.542058706 0.009263224 0.840289593 -1.067256689
|
9 |
+
91224467 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.814787984 -0.085085154 0.573481560 0.311799242 0.076606996 0.996299267 0.038975649 0.234736581 -0.574675500 0.012175811 0.818290770 -1.196836664
|
10 |
+
91391300 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.792760789 -0.089765556 0.602886796 0.270539226 0.081507996 0.995825171 0.041093048 0.259814929 -0.604058623 0.016563140 0.796767771 -1.328140863
|
11 |
+
91558133 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.771091938 -0.093814306 0.629774630 0.223948432 0.087357447 0.995320201 0.041307874 0.293608807 -0.630702674 0.023163332 0.775678813 -1.459775674
|
12 |
+
91724967 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.751535058 -0.096625380 0.652578413 0.178494515 0.089747138 0.994993508 0.043969437 0.308307880 -0.653559864 0.025522472 0.756444395 -1.587897834
|
13 |
+
91891800 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.735025227 -0.099746063 0.670662820 0.138219010 0.093737528 0.994570971 0.045186777 0.333516116 -0.671528995 0.029652854 0.740384698 -1.712296424
|
14 |
+
92092000 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.719043434 -0.101635426 0.687493145 0.093557351 0.098243877 0.994179368 0.044221871 0.373143031 -0.687985957 0.035744540 0.724843204 -1.860791364
|
15 |
+
92258833 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.709150374 -0.102171108 0.697615087 0.059169738 0.099845670 0.994025767 0.044086087 0.400782920 -0.697951734 0.038390182 0.715115070 -1.981529677
|
16 |
+
92425667 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.703360021 -0.101482928 0.703552365 0.039205180 0.098760851 0.994108558 0.044659954 0.417778776 -0.703939617 0.038071405 0.709238708 -2.106152155
|
17 |
+
92592500 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.700221658 -0.101235874 0.706711292 0.029170036 0.096752122 0.994219005 0.046557475 0.427528111 -0.707339108 0.035775267 0.705968499 -2.234683370
|
18 |
+
92759333 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.698873043 -0.100907177 0.708091974 0.024634507 0.096064955 0.994270742 0.046875048 0.444746524 -0.708765149 0.035263114 0.704562664 -2.365965080
|
19 |
+
92926167 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.698221087 -0.101446368 0.708657861 0.017460176 0.096007936 0.994235396 0.047733612 0.465147684 -0.709415078 0.034708157 0.703935742 -2.489595036
|
20 |
+
93126367 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.701540232 -0.100508377 0.705506444 0.030925799 0.096309878 0.994293392 0.045881305 0.500429136 -0.706091821 0.035759658 0.707216740 -2.635113223
|
21 |
+
93293200 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.706040561 -0.099173397 0.701192796 0.055235645 0.095376909 0.994441032 0.044612732 0.522969947 -0.701719284 0.035379205 0.711574554 -2.748741222
|
22 |
+
93460033 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.711469471 -0.101191826 0.695392966 0.092386154 0.097211346 0.994235933 0.045219962 0.550373275 -0.695960581 0.035427466 0.717205524 -2.869640023
|
23 |
+
93626867 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.715287089 -0.106710054 0.690635443 0.110879759 0.101620771 0.993650913 0.048280932 0.574557114 -0.691402614 0.035648178 0.721589625 -3.003606281
|
24 |
+
93793700 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.717362285 -0.111333445 0.687747240 0.117481763 0.104680635 0.993167043 0.051586930 0.589310163 -0.688791215 0.034987301 0.724115014 -3.119467820
|
25 |
+
93960533 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.717443645 -0.117016122 0.686718166 0.111165224 0.109845228 0.992461383 0.054354500 0.614623166 -0.687901616 0.036436427 0.724888742 -3.219243898
|
26 |
+
94160733 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.715022981 -0.122569911 0.688272297 0.080960594 0.115116455 0.991714299 0.057017289 0.647785934 -0.689558089 0.038462799 0.723208308 -3.337481340
|
assets/pose_files/0c9b371cc6225682.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=_ca03xP_KUU
|
2 |
+
212078000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981108844 0.010863926 -0.193151161 0.019480142 -0.008781361 0.999893725 0.011634931 -0.185801323 0.193257034 -0.009719004 0.981100023 -1.207220396
|
3 |
+
212245000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981206656 0.010493318 -0.192674309 0.047262620 -0.008341321 0.999893486 0.011976899 -0.196644454 0.192779467 -0.010144655 0.981189668 -1.332579514
|
4 |
+
212412000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981131375 0.009989015 -0.193084016 0.089602762 -0.007912135 0.999902308 0.011524491 -0.209028987 0.193180263 -0.009779332 0.981114566 -1.458343512
|
5 |
+
212579000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980986536 0.009889571 -0.193823576 0.142988232 -0.007621351 0.999893546 0.012444697 -0.219217661 0.193926007 -0.010730883 0.980957448 -1.565616727
|
6 |
+
212746000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980907381 0.009417370 -0.194247320 0.202069071 -0.007269385 0.999904335 0.011767862 -0.219211705 0.194339558 -0.010131124 0.980881989 -1.654996418
|
7 |
+
212913000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980841637 0.009196524 -0.194589555 0.262465567 -0.006609587 0.999880970 0.013939449 -0.224018296 0.194694594 -0.012386235 0.980785728 -1.740759996
|
8 |
+
213112000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980812550 0.009127630 -0.194739416 0.343576873 -0.006686049 0.999890625 0.013191325 -0.227157741 0.194838524 -0.011636180 0.980766296 -1.843349559
|
9 |
+
213279000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980493903 0.009053789 -0.196340859 0.419120133 -0.006299877 0.999872863 0.014646200 -0.230109231 0.196448520 -0.013123587 0.980426311 -1.921706921
|
10 |
+
213446000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980334044 0.009148465 -0.197133064 0.491943193 -0.006159810 0.999856710 0.015768444 -0.229004834 0.197249070 -0.014244041 0.980249941 -2.001160080
|
11 |
+
213613000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980392158 0.009613466 -0.196821600 0.558373534 -0.006672133 0.999855995 0.015601818 -0.224721707 0.196943253 -0.013982680 0.980315149 -2.074274069
|
12 |
+
213779000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980493963 0.009960363 -0.196296573 0.623893674 -0.006936011 0.999846518 0.016088497 -0.223079036 0.196426690 -0.014413159 0.980412602 -2.137999468
|
13 |
+
213946000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980379820 0.010693249 -0.196827397 0.699812821 -0.007542110 0.999831200 0.016752303 -0.227942951 0.196973309 -0.014939127 0.980295002 -2.197760648
|
14 |
+
214146000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.979374588 0.009917496 -0.201809332 0.814920839 -0.006856939 0.999850750 0.015859045 -0.216071799 0.201936498 -0.014148152 0.979296446 -2.259941063
|
15 |
+
214313000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.977240086 0.010313706 -0.211885542 0.923878346 -0.006808316 0.999827743 0.017266726 -0.213645187 0.212027133 -0.015431152 0.977141917 -2.298546075
|
16 |
+
214480000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.974372387 0.010545789 -0.224693686 1.023098265 -0.007014100 0.999839067 0.016510243 -0.202358523 0.224831656 -0.014511099 0.974289536 -2.336883235
|
17 |
+
214647000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.970687985 0.012000944 -0.240043446 1.114468699 -0.008058093 0.999816120 0.017400362 -0.191482059 0.240208134 -0.014956030 0.970606208 -2.373449288
|
18 |
+
214814000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.966870189 0.013693837 -0.254900873 1.198725810 -0.010257521 0.999837756 0.014805457 -0.166976597 0.255062282 -0.011700304 0.966853857 -2.418678595
|
19 |
+
214981000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.964416146 0.015124563 -0.263955861 1.261415498 -0.015491933 0.999879777 0.000689789 -0.124738174 0.263934553 0.003423943 0.964534521 -2.488291986
|
20 |
+
215181000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.961933076 0.016891202 -0.272762626 1.331672110 -0.022902885 0.999559581 -0.018870916 -0.076291319 0.272323757 0.024399608 0.961896241 -2.579417067
|
21 |
+
215348000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.960170150 0.017672766 -0.278856426 1.385267829 -0.031079723 0.998559833 -0.043730423 -0.007773330 0.277681977 0.050655428 0.959336638 -2.653662977
|
22 |
+
215515000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.959171832 0.017540060 -0.282279491 1.424533991 -0.041599691 0.995969117 -0.079466961 0.101994757 0.279747814 0.087965213 0.956035197 -2.725926173
|
23 |
+
215681000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.958615839 0.017946823 -0.284136623 1.452641041 -0.050743751 0.992801547 -0.108490512 0.201790053 0.280144215 0.118418880 0.952625930 -2.789404412
|
24 |
+
215848000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.958124936 0.017704720 -0.285802603 1.475525887 -0.056795157 0.990007460 -0.129071787 0.280252282 0.280661523 0.139899105 0.949556410 -2.857222541
|
25 |
+
216015000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.958040893 0.018267807 -0.286048800 1.493161376 -0.062177986 0.987448573 -0.145186886 0.346856721 0.279806226 0.156880915 0.947151959 -2.929304305
|
26 |
+
216215000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.959765971 0.017995594 -0.280223966 1.499055201 -0.064145394 0.985608101 -0.156403333 0.410155748 0.273376435 0.168085665 0.947107434 -3.033597428
|
assets/pose_files/0f47577ab3441480.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=in69BD2eZqg
|
2 |
+
195161633 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999976993 -0.003071866 0.006052452 0.037627942 0.003092153 0.999989629 -0.003345382 0.206876054 -0.006042113 0.003364020 0.999976099 -0.240768750
|
3 |
+
195328467 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999913514 -0.003797470 0.012590170 0.037371090 0.003835545 0.999988139 -0.003001482 0.258472399 -0.012578622 0.003049513 0.999916255 -0.264166944
|
4 |
+
195495300 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999804139 -0.004247059 0.019329911 0.038498871 0.004311955 0.999985218 -0.003316826 0.307481199 -0.019315537 0.003399526 0.999807656 -0.276803884
|
5 |
+
195662133 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999610126 -0.005408245 0.027391966 0.038009080 0.005529573 0.999975204 -0.004355530 0.361086350 -0.027367733 0.004505298 0.999615252 -0.278727233
|
6 |
+
195828967 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999336481 -0.006239665 0.035883281 0.034735125 0.006456365 0.999961615 -0.005926326 0.417233500 -0.035844926 0.006154070 0.999338388 -0.270773664
|
7 |
+
195995800 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999090433 -0.007104441 0.042045686 0.033419301 0.007331387 0.999959350 -0.005245875 0.473378445 -0.042006709 0.005549357 0.999101937 -0.261640758
|
8 |
+
196196000 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998589635 -0.007975463 0.052489758 0.032633405 0.008245680 0.999953806 -0.004933467 0.535259197 -0.052447986 0.005359322 0.998609304 -0.250263159
|
9 |
+
196362833 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998181939 -0.008554175 0.059662651 0.028043281 0.008866202 0.999948382 -0.004967103 0.576287383 -0.059617080 0.005487053 0.998206258 -0.238836996
|
10 |
+
196529667 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997807026 -0.009027892 0.065571494 0.020762443 0.009380648 0.999943137 -0.005073830 0.611177122 -0.065521955 0.005677806 0.997834980 -0.221059185
|
11 |
+
196696500 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997390270 -0.009296595 0.071597412 0.013903395 0.009683816 0.999940276 -0.005063088 0.639742116 -0.071546070 0.005743211 0.997420788 -0.192511620
|
12 |
+
196863333 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996968508 -0.009326802 0.077245183 0.007660940 0.009716687 0.999941885 -0.004673062 0.661375479 -0.077197112 0.005409463 0.997001171 -0.161790087
|
13 |
+
197030167 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996511698 -0.009557574 0.082903855 0.000657017 0.009994033 0.999938309 -0.004851241 0.672208252 -0.082852371 0.005662862 0.996545732 -0.126490956
|
14 |
+
197230367 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996102691 -0.010129508 0.087617576 -0.013035317 0.010638822 0.999929130 -0.005347892 0.673139255 -0.087557197 0.006259197 0.996139824 -0.073934910
|
15 |
+
197397200 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995961964 -0.010034073 0.089213885 -0.025143057 0.010407046 0.999938965 -0.003716475 0.666403518 -0.089171149 0.004629921 0.996005535 -0.027130940
|
16 |
+
197564033 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995849669 -0.009882330 0.090475440 -0.039230446 0.010261126 0.999940395 -0.003722523 0.652124926 -0.090433262 0.004635453 0.995891750 0.029309661
|
17 |
+
197730867 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995756924 -0.010019524 0.091475174 -0.055068664 0.010371366 0.999940515 -0.003371752 0.630272532 -0.091435947 0.004306168 0.995801628 0.101088973
|
18 |
+
197897700 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995885789 -0.009802628 0.090085521 -0.068138022 0.010283959 0.999935210 -0.004880426 0.600038118 -0.090031847 0.005786783 0.995922089 0.182818315
|
19 |
+
198064533 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995840013 -0.010080555 0.090559855 -0.077047283 0.010255665 0.999946356 -0.001468501 0.569244350 -0.090540186 0.002391143 0.995889962 0.259090585
|
20 |
+
198264733 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995863378 -0.010109165 0.090299115 -0.082291512 0.010262243 0.999946594 -0.001231105 0.534897586 -0.090281844 0.002152683 0.995913923 0.348298991
|
21 |
+
198431567 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996120393 -0.010062339 0.087423652 -0.082856431 0.010250443 0.999945998 -0.001702961 0.509342862 -0.087401800 0.002592486 0.996169746 0.427163225
|
22 |
+
198598400 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996832788 -0.009680700 0.078934617 -0.077252838 0.010075771 0.999938607 -0.004608278 0.480071534 -0.078885160 0.005389010 0.996869147 0.513721870
|
23 |
+
198765233 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997568011 -0.009045602 0.069110014 -0.060091805 0.009451369 0.999939978 -0.005546598 0.444060897 -0.069055691 0.006186293 0.997593641 0.602911453
|
24 |
+
198932067 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998626053 -0.008290987 0.051742285 -0.037270541 0.008407482 0.999962568 -0.002034174 0.410440195 -0.051723484 0.002466401 0.998658419 0.690111645
|
25 |
+
199098900 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999656141 -0.006388811 0.025431594 -0.014759559 0.006480854 0.999972761 -0.003538476 0.375793364 -0.025408294 0.003702078 0.999670327 0.777147280
|
26 |
+
199299100 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999947906 -0.003541293 -0.009570339 -0.002547566 0.003502194 0.999985456 -0.004099103 0.343015758 0.009584717 0.004065373 0.999945819 0.878377059
|
assets/pose_files/0f68374b76390082.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=-aldZQifF2U
|
2 |
+
103837067 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.792261064 -0.075338066 0.605513453 -2.753106466 0.083067641 0.996426642 0.015288832 0.122302125 -0.604501545 0.038185827 0.795688212 -1.791608923
|
3 |
+
104003900 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.772824645 -0.077280566 0.629896700 -2.856354365 0.084460691 0.996253133 0.018602582 0.115028772 -0.628974140 0.038824979 0.776456118 -1.799931844
|
4 |
+
104170733 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.752573133 -0.078496389 0.653813422 -2.957175162 0.085868694 0.996090353 0.020750597 0.112823623 -0.652886093 0.040525761 0.756371260 -1.810994932
|
5 |
+
104337567 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.730659664 -0.077806436 0.678293884 -3.062095207 0.087071396 0.995992005 0.020455774 0.121362801 -0.677166939 0.044113789 0.734505892 -1.811030009
|
6 |
+
104504400 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.706461906 -0.074765891 0.703790903 -3.177137127 0.086851373 0.996047020 0.018632174 0.129874960 -0.702401876 0.047962286 0.710162818 -1.792277939
|
7 |
+
104671233 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.681627631 -0.071119592 0.728234708 -3.294052837 0.086847432 0.996093273 0.015989548 0.143049226 -0.726526856 0.052346393 0.685141265 -1.768016440
|
8 |
+
104871433 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.649465024 -0.065721743 0.757545888 -3.442979418 0.086763002 0.996156216 0.012038323 0.166510317 -0.755425274 0.057908490 0.652670860 -1.724684703
|
9 |
+
105038267 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.621174812 -0.061671518 0.781241655 -3.558270668 0.087205477 0.996146977 0.009298084 0.180848136 -0.778804898 0.062352814 0.624159455 -1.675155675
|
10 |
+
105205100 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.591690660 -0.058407109 0.804046512 -3.660407702 0.087778911 0.996109724 0.007763143 0.186383384 -0.801371992 0.065984949 0.594515741 -1.621257762
|
11 |
+
105371933 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.561377883 -0.055783633 0.825677335 -3.752373081 0.089227341 0.995989263 0.006624432 0.194667304 -0.822735310 0.069954179 0.564103782 -1.568545872
|
12 |
+
105538767 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.531322777 -0.053783599 0.845460474 -3.836961453 0.091640897 0.995775461 0.005754844 0.205166191 -0.842198372 0.074421078 0.534006953 -1.522108893
|
13 |
+
105705600 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.501615226 -0.052979972 0.863467038 -3.914896511 0.093892507 0.995560884 0.006539768 0.201989601 -0.859980464 0.077792637 0.504362881 -1.476983336
|
14 |
+
105905800 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.466019660 -0.052434672 0.883219302 -4.004424531 0.098161966 0.995143771 0.007285428 0.209186293 -0.879312158 0.083303392 0.468903631 -1.424243874
|
15 |
+
106072633 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.435604274 -0.051984914 0.898635924 -4.083866394 0.101657487 0.994785070 0.008269622 0.213039517 -0.894379497 0.087750785 0.438617289 -1.372398599
|
16 |
+
106239467 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.404144615 -0.051714677 0.913232028 -4.163043658 0.104999557 0.994423509 0.009845568 0.210349578 -0.908648551 0.091909930 0.407320917 -1.308948274
|
17 |
+
106406300 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.372057080 -0.052390546 0.926730156 -4.232320456 0.108426183 0.994023800 0.012664661 0.202983014 -0.921855330 0.095769837 0.375514120 -1.239784641
|
18 |
+
106573133 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.338993609 -0.053159237 0.939285576 -4.297918560 0.111065105 0.993681848 0.016153777 0.191918628 -0.934209764 0.098845825 0.342755914 -1.169019518
|
19 |
+
106739967 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.305330545 -0.054462686 0.950687706 -4.358390475 0.113597691 0.993316948 0.020420864 0.175834622 -0.945446372 0.101760812 0.309476852 -1.098186456
|
20 |
+
106940167 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.264192373 -0.056177936 0.962832510 -4.426586558 0.117604628 0.992729127 0.025652671 0.163465045 -0.957272947 0.106456317 0.268878251 -1.008524756
|
21 |
+
107107000 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.228878200 -0.056077410 0.971838534 -4.485196000 0.120451130 0.992298782 0.028890507 0.159180748 -0.965974271 0.110446639 0.233870149 -0.923927626
|
22 |
+
107273833 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.192813009 -0.054965079 0.979694843 -4.547398479 0.122527294 0.991963863 0.031538919 0.153786345 -0.973555446 0.113958240 0.197998255 -0.835885482
|
23 |
+
107440667 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.155427963 -0.053089641 0.986419618 -4.614075971 0.124575593 0.991636276 0.033741303 0.151495104 -0.979960740 0.117639467 0.160741687 -0.738650735
|
24 |
+
107607500 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.117806904 -0.051662166 0.991691768 -4.672324721 0.126608506 0.991277337 0.036600262 0.144364476 -0.984932423 0.121244848 0.123320177 -0.639080225
|
25 |
+
107774333 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.080108978 -0.050879046 0.995486736 -4.716803649 0.129048899 0.990820825 0.040255725 0.133545828 -0.988397181 0.125241622 0.085939527 -0.541709066
|
26 |
+
107974533 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.034108389 -0.050325166 0.998150289 -4.758242879 0.132215530 0.990180492 0.045405328 0.118994547 -0.990633965 0.130422264 0.040427230 -0.433560831
|
assets/pose_files/2c80f9eb0d3b2bb4.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=sLIFyXD2ujI
|
2 |
+
77010267 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.991455436 0.021077231 -0.128731906 -0.119416025 -0.023393147 0.999590099 -0.016504617 0.019347615 0.128331259 0.019375037 0.991542101 -0.092957340
|
3 |
+
77143733 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.988697171 0.023564288 -0.148062930 -0.142632843 -0.026350429 0.999510169 -0.016883694 0.023384606 0.147592559 0.020594381 0.988833785 -0.115024468
|
4 |
+
77277200 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.985320270 0.026362764 -0.168668360 -0.165155176 -0.029713295 0.999407530 -0.017371174 0.028412548 0.168110475 0.022127863 0.985519767 -0.141363672
|
5 |
+
77410667 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.981402338 0.029484071 -0.189684242 -0.188834577 -0.033494804 0.999277294 -0.017972585 0.034114674 0.189017251 0.023991773 0.981680632 -0.169959835
|
6 |
+
77544133 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.976626754 0.033091862 -0.212379664 -0.212527322 -0.037670061 0.999136209 -0.017545532 0.036524990 0.211615592 0.025135791 0.977029681 -0.204014687
|
7 |
+
77677600 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.970793188 0.035301749 -0.237306431 -0.229683177 -0.040712819 0.999009848 -0.017938551 0.042000619 0.236438200 0.027076038 0.971269190 -0.236341621
|
8 |
+
77811067 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.964167893 0.038360216 -0.262504756 -0.246031807 -0.044686489 0.998835802 -0.018170038 0.047261141 0.261502147 0.029249383 0.964759588 -0.276015669
|
9 |
+
77944533 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.956261098 0.040829532 -0.289650917 -0.252766079 -0.048421524 0.998644531 -0.019089982 0.054620904 0.288478881 0.032280345 0.956941962 -0.321621308
|
10 |
+
78078000 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.946828246 0.042435788 -0.318928629 -0.251662187 -0.051583275 0.998462617 -0.020286530 0.062274582 0.317577451 0.035659242 0.947561622 -0.373008852
|
11 |
+
78211467 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.935860872 0.044850163 -0.349503726 -0.247351407 -0.055966165 0.998195410 -0.021766055 0.072942153 0.347896785 0.039930381 0.936682105 -0.431307858
|
12 |
+
78344933 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.923219025 0.046543088 -0.381445110 -0.234020172 -0.059361201 0.997996330 -0.021899769 0.078518674 0.379661530 0.042861324 0.924132049 -0.487708973
|
13 |
+
78478400 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.909880757 0.048629351 -0.412009954 -0.218247042 -0.063676558 0.997708619 -0.022863906 0.088967126 0.409954011 0.047038805 0.910892427 -0.543114491
|
14 |
+
78645233 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.891359746 0.050869841 -0.450433195 -0.185763327 -0.067926541 0.997452736 -0.021771761 0.093745158 0.448178291 0.050002839 0.892544627 -0.611223637
|
15 |
+
78778700 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.877080619 0.053094681 -0.477399796 -0.163606786 -0.072203092 0.997152746 -0.021752052 0.102191599 0.474885583 0.053548045 0.878416896 -0.664313657
|
16 |
+
78912167 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.863215029 0.055334236 -0.501794696 -0.143825544 -0.076518841 0.996831775 -0.021708660 0.111709364 0.499003649 0.057135988 0.864714324 -0.719103228
|
17 |
+
79045633 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.849143744 0.056816563 -0.525096953 -0.122334566 -0.079860382 0.996578217 -0.021311868 0.118459005 0.522089303 0.060031284 0.850775540 -0.775464728
|
18 |
+
79179100 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.835254073 0.059146367 -0.546673834 -0.101344556 -0.084243484 0.996225357 -0.020929486 0.126763936 0.543372452 0.063535146 0.837083995 -0.832841061
|
19 |
+
79312567 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.822106183 0.061935693 -0.565955281 -0.082663275 -0.088697352 0.995860636 -0.019859029 0.133423045 0.562382638 0.066524968 0.824196696 -0.894100189
|
20 |
+
79446033 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.808796465 0.064479858 -0.584543109 -0.062439027 -0.093630031 0.995411158 -0.019748187 0.145033126 0.580587387 0.070703052 0.811122298 -0.951788129
|
21 |
+
79579500 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.794901192 0.066913344 -0.603037894 -0.037377988 -0.097949244 0.995015621 -0.018705536 0.153829045 0.598780453 0.073936157 0.797493160 -1.008854626
|
22 |
+
79712967 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.781648815 0.069040783 -0.619885862 -0.013285614 -0.101820730 0.994646847 -0.017611075 0.161173621 0.615351617 0.076882906 0.784494340 -1.070102980
|
23 |
+
79846433 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.768634439 0.072034292 -0.635619521 0.012177816 -0.107280776 0.994082153 -0.017072625 0.174403322 0.630628169 0.081312358 0.771813691 -1.132424688
|
24 |
+
79979900 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.755315959 0.075072937 -0.651046753 0.040377463 -0.113015875 0.993455172 -0.016559631 0.189742153 0.645542622 0.086086370 0.758856952 -1.193296093
|
25 |
+
80113367 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.742196620 0.078075886 -0.665618777 0.069020519 -0.118120082 0.992882252 -0.015246205 0.202486741 0.659690738 0.089938626 0.746136189 -1.254564875
|
26 |
+
80280200 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.726252913 0.078639805 -0.682914674 0.104603927 -0.119984925 0.992686808 -0.013288199 0.209760187 0.676875412 0.091590062 0.730377257 -1.329527748
|
assets/pose_files/2f25826f0d0ef09a.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=t-mlAKnESzQ
|
2 |
+
167300000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991854608 -0.011446482 0.126859888 0.441665245 0.012175850 0.999913514 -0.004975420 -0.056449972 -0.126791954 0.006479521 0.991908193 -0.456202583
|
3 |
+
167467000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991945148 -0.011409644 0.126153216 0.506974565 0.012122569 0.999914587 -0.004884966 -0.069421149 -0.126086697 0.006374919 0.991998732 -0.517325825
|
4 |
+
167634000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991982698 -0.011117884 0.125883758 0.561996906 0.011760475 0.999921322 -0.004362585 -0.080740919 -0.125825346 0.005808061 0.992035389 -0.570476997
|
5 |
+
167801000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.992486775 -0.010566713 0.121894784 0.609598405 0.011126007 0.999930441 -0.003908583 -0.087745179 -0.121845007 0.005235420 0.992535353 -0.617968773
|
6 |
+
167968000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993018925 -0.010175236 0.117515638 0.655818155 0.010723241 0.999934375 -0.004031916 -0.098194076 -0.117466904 0.005263917 0.993062854 -0.668642428
|
7 |
+
168134000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993561447 -0.009874708 0.112863302 0.703081750 0.010385432 0.999938309 -0.003938090 -0.108951006 -0.112817451 0.005084869 0.993602693 -0.730919086
|
8 |
+
168335000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994062483 -0.010140099 0.108337507 0.763671544 0.010665529 0.999934018 -0.004271581 -0.104826596 -0.108287044 0.005401695 0.994104981 -0.820197463
|
9 |
+
168501000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994565487 -0.010703885 0.103560977 0.813888661 0.011249267 0.999925733 -0.004683641 -0.095187847 -0.103503153 0.005823173 0.994612098 -0.890086513
|
10 |
+
168668000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994905293 -0.010604435 0.100254886 0.865790711 0.011124965 0.999927402 -0.004634405 -0.086100908 -0.100198455 0.005726126 0.994951010 -0.962092459
|
11 |
+
168835000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995332658 -0.010553311 0.095924698 0.905775925 0.011022467 0.999929726 -0.004362300 -0.075394333 -0.095871925 0.005399267 0.995379031 -1.025694236
|
12 |
+
169002000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995705128 -0.010036361 0.092035979 0.944576676 0.010483396 0.999935448 -0.004374997 -0.058609663 -0.091986135 0.005321057 0.995746076 -1.081030198
|
13 |
+
169169000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996029556 -0.009414902 0.088523701 0.977259045 0.009879347 0.999939620 -0.004809874 -0.042104006 -0.088473074 0.005665333 0.996062458 -1.127427189
|
14 |
+
169369000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996554554 -0.009220830 0.082425818 1.013619994 0.009555685 0.999947608 -0.003668923 -0.018710063 -0.082387671 0.004443917 0.996590436 -1.175459833
|
15 |
+
169536000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996902764 -0.008823335 0.078147218 1.041872487 0.009157063 0.999950409 -0.003913174 0.011864113 -0.078108817 0.004616653 0.996934175 -1.202554477
|
16 |
+
169703000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997331142 -0.008540447 0.072509713 1.068829435 0.008805763 0.999955654 -0.003340158 0.047323405 -0.072477967 0.003969747 0.997362137 -1.214284849
|
17 |
+
169870000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997695088 -0.008219596 0.067356706 1.095289713 0.008451649 0.999959290 -0.003160893 0.090756953 -0.067327984 0.003722883 0.997723937 -1.225599061
|
18 |
+
170036000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997950792 -0.008326715 0.063442364 1.112874795 0.008502332 0.999960721 -0.002498670 0.132311648 -0.063419066 0.003032958 0.997982383 -1.233305313
|
19 |
+
170203000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998197436 -0.008063688 0.059471287 1.125840626 0.008245971 0.999962032 -0.002820280 0.178666038 -0.059446286 0.003305595 0.998226047 -1.240809047
|
20 |
+
170403000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998360872 -0.007677370 0.056714825 1.144370603 0.007830821 0.999966264 -0.002483922 0.248055953 -0.056693841 0.002923975 0.998387337 -1.246230780
|
21 |
+
170570000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998471320 -0.007715963 0.054730706 1.159189486 0.007868989 0.999965727 -0.002581036 0.310163907 -0.054708913 0.003007766 0.998497844 -1.245661417
|
22 |
+
170737000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998507679 -0.007751614 0.054058932 1.165836593 0.007918007 0.999964535 -0.002864495 0.366963293 -0.054034811 0.003288259 0.998533666 -1.241523115
|
23 |
+
170904000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998536825 -0.007817166 0.053508084 1.175042384 0.008036798 0.999960124 -0.003890704 0.423587941 -0.053475536 0.004315045 0.998559833 -1.224956309
|
24 |
+
171071000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998537302 -0.007878507 0.053490099 1.177855699 0.008138275 0.999956131 -0.004640296 0.484100754 -0.053451192 0.005068825 0.998557627 -1.202906710
|
25 |
+
171238000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998549581 -0.007872007 0.053261518 1.180678596 0.008100130 0.999958932 -0.004068544 0.548228374 -0.053227302 0.004494068 0.998572290 -1.184901744
|
26 |
+
171438000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998469293 -0.008281939 0.054685175 1.181414517 0.008542870 0.999953210 -0.004539483 0.618089736 -0.054645021 0.004999703 0.998493314 -1.159911786
|
assets/pose_files/3f79dc32d575bcdc.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=1qVpRlWxam4
|
2 |
+
87387300 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998291552 0.018666664 0.055367537 -0.431348097 -0.018963017 0.999808490 0.004831879 0.070488701 -0.055266738 -0.005873560 0.998454332 -0.848986490
|
3 |
+
87554133 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997851610 0.017319093 0.063184217 -0.464904483 -0.017837154 0.999811709 0.007644337 0.068569507 -0.063039921 -0.008754940 0.997972608 -0.876888649
|
4 |
+
87720967 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997675776 0.016262729 0.066170901 -0.486385324 -0.016915560 0.999813497 0.009317505 0.069230577 -0.066007033 -0.010415167 0.997764826 -0.912234761
|
5 |
+
87887800 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997801185 0.015721651 0.064386748 -0.496646826 -0.016416471 0.999812424 0.010276551 0.072350447 -0.064213105 -0.011310958 0.997872114 -0.952896762
|
6 |
+
88054633 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998319149 0.016226118 0.055637561 -0.489176520 -0.016823635 0.999805570 0.010287891 0.076572802 -0.055459809 -0.011206625 0.998398006 -1.004124831
|
7 |
+
88221467 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998926342 0.017387087 0.042939600 -0.475558168 -0.017787032 0.999801755 0.008949692 0.085218470 -0.042775478 -0.009703852 0.999037564 -1.053459508
|
8 |
+
88421667 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999537408 0.020246139 0.022695299 -0.447975333 -0.020412439 0.999766290 0.007119950 0.094693503 -0.022545842 -0.007579923 0.999717057 -1.119813421
|
9 |
+
88588500 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999718606 0.023644496 0.001895716 -0.414069396 -0.023654999 0.999703765 0.005723978 0.102792865 -0.001759814 -0.005767211 0.999981821 -1.180436614
|
10 |
+
88755333 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999452412 0.027961638 -0.017690983 -0.387314056 -0.027902454 0.999604225 0.003583529 0.113408687 0.017784184 -0.003087946 0.999837101 -1.226234160
|
11 |
+
88922167 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998792231 0.032815013 -0.036568113 -0.365929800 -0.032777511 0.999461353 0.001624677 0.124849793 0.036601726 -0.000424103 0.999329865 -1.267691893
|
12 |
+
89089000 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997681975 0.038422074 -0.056164715 -0.342733324 -0.038413495 0.999261141 0.001232749 0.131945819 0.056170583 0.000927592 0.998420775 -1.304181539
|
13 |
+
89255833 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.995796800 0.044428598 -0.080092981 -0.304097608 -0.044486329 0.999009430 0.001064335 0.139304626 0.080060929 0.002503182 0.996786833 -1.346184197
|
14 |
+
89456033 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.992265880 0.051900670 -0.112759680 -0.242293511 -0.051975533 0.998645782 0.002277754 0.141999546 0.112725191 0.003600606 0.993619680 -1.403491443
|
15 |
+
89622867 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.988518834 0.057282198 -0.139818728 -0.191310403 -0.057411496 0.998345733 0.003111851 0.144317113 0.139765680 0.004951079 0.990172207 -1.446433054
|
16 |
+
89789700 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.984156251 0.062501073 -0.165921792 -0.143876127 -0.062264379 0.998037636 0.006632906 0.137240925 0.166010767 0.003803201 0.986116588 -1.485275757
|
17 |
+
89956533 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.979292631 0.066839822 -0.191097870 -0.099323029 -0.066578977 0.997750700 0.007792730 0.139573975 0.191188902 0.005091738 0.981540024 -1.518326120
|
18 |
+
90123367 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.973332286 0.070821166 -0.218194127 -0.042629488 -0.070645541 0.997464299 0.008616162 0.140175484 0.218251050 0.007028054 0.975867331 -1.554681376
|
19 |
+
90290200 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.966279447 0.074934490 -0.246350974 0.028017454 -0.074612871 0.997155666 0.010653359 0.133648148 0.246448576 0.008086831 0.969122112 -1.595505702
|
20 |
+
90490400 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.957038641 0.079540767 -0.278837353 0.115624588 -0.079602204 0.996764660 0.011121323 0.132533757 0.278819799 0.011552530 0.960273921 -1.622873069
|
21 |
+
90657233 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.948620677 0.083499879 -0.305199176 0.195920884 -0.084326349 0.996382892 0.010498469 0.132694923 0.304971874 0.015777269 0.952230692 -1.640734525
|
22 |
+
90824067 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.940162480 0.087165222 -0.329388469 0.271130852 -0.089630231 0.995945156 0.007725847 0.141518901 0.328726262 0.022259612 0.944162905 -1.645387258
|
23 |
+
90990900 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.932207108 0.091082737 -0.350276887 0.339575901 -0.095441416 0.995423317 0.004838228 0.149739069 0.349114448 0.028920690 0.936633706 -1.648637528
|
24 |
+
91157733 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.925111592 0.095387921 -0.367518306 0.398473437 -0.101803973 0.994802594 0.001937620 0.156527229 0.365792990 0.035622310 0.930014253 -1.650611039
|
25 |
+
91324567 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.918998003 0.099750437 -0.381434739 0.448520864 -0.108052738 0.994145095 -0.000350893 0.159817007 0.379166484 0.041537538 0.924395680 -1.652156379
|
26 |
+
91524767 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.913031578 0.105410583 -0.394032896 0.493990424 -0.115993932 0.993245184 -0.003064641 0.163621223 0.391048223 0.048503540 0.919091225 -1.650421710
|
assets/pose_files/4a2d6753676df096.txt
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
https://www.youtube.com/watch?v=mGFQkgadzRQ
|
2 |
+
123373000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.998857915 0.002672890 -0.047704928 -0.388737999 -0.002653247 0.999996364 0.000475094 -0.004533370 0.047706023 -0.000347978 0.998861372 0.139698036
|
3 |
+
123581000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.997534156 0.002900333 -0.070122920 -0.417036011 -0.002881077 0.999995768 0.000375740 -0.005476288 0.070123710 -0.000172784 0.997538269 0.134851393
|
4 |
+
123790000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.995756805 0.003055056 -0.091973245 -0.444572396 -0.003032017 0.999995351 0.000390221 -0.006227409 0.091974013 -0.000109701 0.995761395 0.129660844
|
5 |
+
123999000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.993462563 0.003229393 -0.114112593 -0.472377562 -0.003208589 0.999994814 0.000365978 -0.005932507 0.114113182 0.000002555 0.993467748 0.123959606
|
6 |
+
124207000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.990603268 0.003450655 -0.136723205 -0.500173495 -0.003429445 0.999994040 0.000390680 -0.006082111 0.136723727 0.000081876 0.990609229 0.117333920
|
7 |
+
124416000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.987169921 0.003684058 -0.159630775 -0.528663584 -0.003696360 0.999993145 0.000219867 -0.006000823 0.159630492 0.000373007 0.987176776 0.110039363
|
8 |
+
124666000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.982509255 0.003945273 -0.186171964 -0.561235187 -0.003999375 0.999992013 0.000084966 -0.007105507 0.186170816 0.000661092 0.982517183 0.100220962
|
9 |
+
124874000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.978331029 0.004287674 -0.207002580 -0.586641713 -0.004329238 0.999990582 0.000252201 -0.009076863 0.207001716 0.000649428 0.978340387 0.091930702
|
10 |
+
125083000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.973817825 0.004402113 -0.227287307 -0.611123286 -0.004493149 0.999989927 0.000116853 -0.009074310 0.227285519 0.000907443 0.973827720 0.083304516
|
11 |
+
125292000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.968695283 0.004357880 -0.248214558 -0.636185581 -0.004474392 0.999989986 0.000094731 -0.008011808 0.248212487 0.001018844 0.968705058 0.074442714
|
12 |
+
125500000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.962747812 0.004391920 -0.270365149 -0.662786287 -0.004570082 0.999989569 -0.000029452 -0.006714359 0.270362198 0.001263946 0.962757826 0.064526619
|
13 |
+
125709000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.955500066 0.004409539 -0.294957876 -0.691555299 -0.004778699 0.999988437 -0.000530787 -0.004279872 0.294952124 0.001916682 0.955510139 0.052776269
|
14 |
+
125959000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.946343422 0.004634521 -0.323129416 -0.724457169 -0.005231380 0.999985814 -0.000978639 -0.001732190 0.323120296 0.002616541 0.946354270 0.037519903
|
15 |
+
126167000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.939147174 0.004834646 -0.343481004 -0.749049950 -0.005533603 0.999984145 -0.001054784 -0.002170622 0.343470454 0.002891285 0.939159036 0.026149102
|
16 |
+
126376000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.931589127 0.004859613 -0.363480538 -0.772472331 -0.005669596 0.999983251 -0.001161554 -0.002324411 0.363468796 0.003142879 0.931601048 0.014526636
|
17 |
+
126584000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.923323810 0.004994850 -0.383989871 -0.796474752 -0.005933880 0.999981582 -0.001260800 -0.001656055 0.383976519 0.003442676 0.923336446 0.001805353
|
18 |
+
126793000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.914980114 0.005216786 -0.403465271 -0.819526045 -0.006272262 0.999979496 -0.001294577 -0.001858109 0.403450221 0.003715152 0.914994061 -0.010564998
|
19 |
+
127002000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.906303227 0.005258658 -0.422595352 -0.842292418 -0.006397304 0.999978721 -0.001276282 -0.001621911 0.422579646 0.003860169 0.906317592 -0.023561723
|
20 |
+
127252000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.894903898 0.005439967 -0.446225733 -0.870190326 -0.006754198 0.999976277 -0.001354740 -0.001280526 0.446207762 0.004226258 0.894919395 -0.040196739
|
21 |
+
127460000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.884573221 0.005480251 -0.466369092 -0.894082633 -0.006980692 0.999974549 -0.001489853 -0.000948027 0.466349065 0.004573463 0.884588957 -0.055396928
|
22 |
+
127669000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.873941839 0.005343055 -0.486001104 -0.917185865 -0.007038773 0.999973834 -0.001663705 -0.000687769 0.485979497 0.004874832 0.873956621 -0.070420475
|
23 |
+
127877000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.862660766 0.005402398 -0.505754173 -0.939888304 -0.007276668 0.999972045 -0.001730187 -0.000489221 0.505730629 0.005172769 0.862675905 -0.086411685
|
24 |
+
128086000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.851201892 0.005282878 -0.524811804 -0.961420775 -0.007401088 0.999970734 -0.001938020 -0.000533338 0.524786234 0.005533825 0.851216078 -0.102931062
|
25 |
+
128295000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.839382112 0.005324626 -0.543515682 -0.982849443 -0.007655066 0.999968648 -0.002025823 0.001148876 0.543487847 0.005861088 0.839396596 -0.119132721
|
26 |
+
128545000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.825496972 0.005258156 -0.564382017 -1.006530766 -0.007933038 0.999965906 -0.002286965 0.002382303 0.564350784 0.006365147 0.825510561 -0.138386240
|
assets/reference_videos/0bf152ef84195293.mp4
ADDED
Binary file (231 kB). View file
|
|
assets/reference_videos/0c11dbe781b1c11c.mp4
ADDED
Binary file (219 kB). View file
|
|
assets/reference_videos/0c9b371cc6225682.mp4
ADDED
Binary file (195 kB). View file
|
|
assets/reference_videos/0f47577ab3441480.mp4
ADDED
Binary file (161 kB). View file
|
|
assets/reference_videos/0f68374b76390082.mp4
ADDED
Binary file (299 kB). View file
|
|
assets/reference_videos/2c80f9eb0d3b2bb4.mp4
ADDED
Binary file (173 kB). View file
|
|
assets/reference_videos/2f25826f0d0ef09a.mp4
ADDED
Binary file (195 kB). View file
|
|
assets/reference_videos/3f79dc32d575bcdc.mp4
ADDED
Binary file (148 kB). View file
|
|
assets/reference_videos/4a2d6753676df096.mp4
ADDED
Binary file (229 kB). View file
|
|
cameractrl/data/dataset.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import json
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
import torchvision.transforms as transforms
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from decord import VideoReader
|
12 |
+
from torch.utils.data.dataset import Dataset
|
13 |
+
from packaging import version as pver
|
14 |
+
|
15 |
+
|
16 |
+
class RandomHorizontalFlipWithPose(nn.Module):
|
17 |
+
def __init__(self, p=0.5):
|
18 |
+
super(RandomHorizontalFlipWithPose, self).__init__()
|
19 |
+
self.p = p
|
20 |
+
|
21 |
+
def get_flip_flag(self, n_image):
|
22 |
+
return torch.rand(n_image) < self.p
|
23 |
+
|
24 |
+
def forward(self, image, flip_flag=None):
|
25 |
+
n_image = image.shape[0]
|
26 |
+
if flip_flag is not None:
|
27 |
+
assert n_image == flip_flag.shape[0]
|
28 |
+
else:
|
29 |
+
flip_flag = self.get_flip_flag(n_image)
|
30 |
+
|
31 |
+
ret_images = []
|
32 |
+
for fflag, img in zip(flip_flag, image):
|
33 |
+
if fflag:
|
34 |
+
ret_images.append(F.hflip(img))
|
35 |
+
else:
|
36 |
+
ret_images.append(img)
|
37 |
+
return torch.stack(ret_images, dim=0)
|
38 |
+
|
39 |
+
|
40 |
+
class Camera(object):
|
41 |
+
def __init__(self, entry):
|
42 |
+
fx, fy, cx, cy = entry[1:5]
|
43 |
+
self.fx = fx
|
44 |
+
self.fy = fy
|
45 |
+
self.cx = cx
|
46 |
+
self.cy = cy
|
47 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
48 |
+
w2c_mat_4x4 = np.eye(4)
|
49 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
50 |
+
self.w2c_mat = w2c_mat_4x4
|
51 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
52 |
+
|
53 |
+
|
54 |
+
def custom_meshgrid(*args):
|
55 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
56 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
57 |
+
return torch.meshgrid(*args)
|
58 |
+
else:
|
59 |
+
return torch.meshgrid(*args, indexing='ij')
|
60 |
+
|
61 |
+
|
62 |
+
def ray_condition(K, c2w, H, W, device, flip_flag=None):
|
63 |
+
# c2w: B, V, 4, 4
|
64 |
+
# K: B, V, 4
|
65 |
+
|
66 |
+
B, V = K.shape[:2]
|
67 |
+
|
68 |
+
j, i = custom_meshgrid(
|
69 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
70 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
71 |
+
)
|
72 |
+
i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
|
73 |
+
j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
|
74 |
+
|
75 |
+
n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
|
76 |
+
if n_flip > 0:
|
77 |
+
j_flip, i_flip = custom_meshgrid(
|
78 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
79 |
+
torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
|
80 |
+
)
|
81 |
+
i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
|
82 |
+
j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
|
83 |
+
i[:, flip_flag, ...] = i_flip
|
84 |
+
j[:, flip_flag, ...] = j_flip
|
85 |
+
|
86 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
87 |
+
|
88 |
+
zs = torch.ones_like(i) # [B, V, HxW]
|
89 |
+
xs = (i - cx) / fx * zs
|
90 |
+
ys = (j - cy) / fy * zs
|
91 |
+
zs = zs.expand_as(ys)
|
92 |
+
|
93 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
94 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
95 |
+
|
96 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3
|
97 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
98 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3
|
99 |
+
# c2w @ dirctions
|
100 |
+
rays_dxo = torch.linalg.cross(rays_o, rays_d) # B, V, HW, 3
|
101 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
102 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
103 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
104 |
+
return plucker
|
105 |
+
|
106 |
+
|
107 |
+
class RealEstate10K(Dataset):
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
root_path,
|
111 |
+
annotation_json,
|
112 |
+
sample_stride=4,
|
113 |
+
sample_n_frames=16,
|
114 |
+
sample_size=[256, 384],
|
115 |
+
is_image=False,
|
116 |
+
):
|
117 |
+
self.root_path = root_path
|
118 |
+
self.sample_stride = sample_stride
|
119 |
+
self.sample_n_frames = sample_n_frames
|
120 |
+
self.is_image = is_image
|
121 |
+
|
122 |
+
self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
|
123 |
+
self.length = len(self.dataset)
|
124 |
+
|
125 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
126 |
+
pixel_transforms = [transforms.Resize(sample_size),
|
127 |
+
transforms.RandomHorizontalFlip(),
|
128 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
|
129 |
+
|
130 |
+
self.pixel_transforms = transforms.Compose(pixel_transforms)
|
131 |
+
|
132 |
+
def load_video_reader(self, idx):
|
133 |
+
video_dict = self.dataset[idx]
|
134 |
+
|
135 |
+
video_path = os.path.join(self.root_path, video_dict['clip_path'])
|
136 |
+
video_reader = VideoReader(video_path)
|
137 |
+
return video_reader, video_dict['caption']
|
138 |
+
|
139 |
+
def get_batch(self, idx):
|
140 |
+
video_reader, video_caption = self.load_video_reader(idx)
|
141 |
+
total_frames = len(video_reader)
|
142 |
+
|
143 |
+
if self.is_image:
|
144 |
+
frame_indice = [random.randint(0, total_frames - 1)]
|
145 |
+
else:
|
146 |
+
if isinstance(self.sample_stride, int):
|
147 |
+
current_sample_stride = self.sample_stride
|
148 |
+
else:
|
149 |
+
assert len(self.sample_stride) == 2
|
150 |
+
assert (self.sample_stride[0] >= 1) and (self.sample_stride[1] >= self.sample_stride[0])
|
151 |
+
current_sample_stride = random.randint(self.sample_stride[0], self.sample_stride[1])
|
152 |
+
|
153 |
+
cropped_length = self.sample_n_frames * current_sample_stride
|
154 |
+
start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
|
155 |
+
end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
|
156 |
+
|
157 |
+
assert end_frame_ind - start_frame_ind >= self.sample_n_frames
|
158 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)
|
159 |
+
|
160 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
161 |
+
pixel_values = pixel_values / 255.
|
162 |
+
|
163 |
+
if self.is_image:
|
164 |
+
pixel_values = pixel_values[0]
|
165 |
+
|
166 |
+
return pixel_values, video_caption
|
167 |
+
|
168 |
+
def __len__(self):
|
169 |
+
return self.length
|
170 |
+
|
171 |
+
def __getitem__(self, idx):
|
172 |
+
while True:
|
173 |
+
try:
|
174 |
+
video, video_caption = self.get_batch(idx)
|
175 |
+
break
|
176 |
+
|
177 |
+
except Exception as e:
|
178 |
+
idx = random.randint(0, self.length - 1)
|
179 |
+
|
180 |
+
video = self.pixel_transforms(video)
|
181 |
+
sample = dict(pixel_values=video, caption=video_caption)
|
182 |
+
|
183 |
+
return sample
|
184 |
+
|
185 |
+
|
186 |
+
class RealEstate10KPose(Dataset):
|
187 |
+
def __init__(
|
188 |
+
self,
|
189 |
+
root_path,
|
190 |
+
annotation_json,
|
191 |
+
sample_stride=4,
|
192 |
+
minimum_sample_stride=1,
|
193 |
+
sample_n_frames=16,
|
194 |
+
relative_pose=False,
|
195 |
+
zero_t_first_frame=False,
|
196 |
+
sample_size=[256, 384],
|
197 |
+
rescale_fxy=False,
|
198 |
+
shuffle_frames=False,
|
199 |
+
use_flip=False,
|
200 |
+
return_clip_name=False,
|
201 |
+
):
|
202 |
+
self.root_path = root_path
|
203 |
+
self.relative_pose = relative_pose
|
204 |
+
self.zero_t_first_frame = zero_t_first_frame
|
205 |
+
self.sample_stride = sample_stride
|
206 |
+
self.minimum_sample_stride = minimum_sample_stride
|
207 |
+
self.sample_n_frames = sample_n_frames
|
208 |
+
self.return_clip_name = return_clip_name
|
209 |
+
|
210 |
+
self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
|
211 |
+
self.length = len(self.dataset)
|
212 |
+
|
213 |
+
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
|
214 |
+
self.sample_size = sample_size
|
215 |
+
if use_flip:
|
216 |
+
pixel_transforms = [transforms.Resize(sample_size),
|
217 |
+
RandomHorizontalFlipWithPose(),
|
218 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
|
219 |
+
else:
|
220 |
+
pixel_transforms = [transforms.Resize(sample_size),
|
221 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
|
222 |
+
self.rescale_fxy = rescale_fxy
|
223 |
+
self.sample_wh_ratio = sample_size[1] / sample_size[0]
|
224 |
+
|
225 |
+
self.pixel_transforms = pixel_transforms
|
226 |
+
self.shuffle_frames = shuffle_frames
|
227 |
+
self.use_flip = use_flip
|
228 |
+
|
229 |
+
def get_relative_pose(self, cam_params):
|
230 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
231 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
232 |
+
source_cam_c2w = abs_c2ws[0]
|
233 |
+
if self.zero_t_first_frame:
|
234 |
+
cam_to_origin = 0
|
235 |
+
else:
|
236 |
+
cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
|
237 |
+
target_cam_c2w = np.array([
|
238 |
+
[1, 0, 0, 0],
|
239 |
+
[0, 1, 0, -cam_to_origin],
|
240 |
+
[0, 0, 1, 0],
|
241 |
+
[0, 0, 0, 1]
|
242 |
+
])
|
243 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
244 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
245 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
246 |
+
return ret_poses
|
247 |
+
|
248 |
+
def load_video_reader(self, idx):
|
249 |
+
video_dict = self.dataset[idx]
|
250 |
+
|
251 |
+
video_path = os.path.join(self.root_path, video_dict['clip_path'])
|
252 |
+
video_reader = VideoReader(video_path)
|
253 |
+
return video_dict['clip_name'], video_reader, video_dict['caption']
|
254 |
+
|
255 |
+
def load_cameras(self, idx):
|
256 |
+
video_dict = self.dataset[idx]
|
257 |
+
pose_file = os.path.join(self.root_path, video_dict['pose_file'])
|
258 |
+
with open(pose_file, 'r') as f:
|
259 |
+
poses = f.readlines()
|
260 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
261 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
262 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
263 |
+
return cam_params
|
264 |
+
|
265 |
+
def get_batch(self, idx):
|
266 |
+
clip_name, video_reader, video_caption = self.load_video_reader(idx)
|
267 |
+
cam_params = self.load_cameras(idx)
|
268 |
+
assert len(cam_params) >= self.sample_n_frames
|
269 |
+
total_frames = len(cam_params)
|
270 |
+
|
271 |
+
current_sample_stride = self.sample_stride
|
272 |
+
|
273 |
+
if total_frames < self.sample_n_frames * current_sample_stride:
|
274 |
+
maximum_sample_stride = int(total_frames // self.sample_n_frames)
|
275 |
+
current_sample_stride = random.randint(self.minimum_sample_stride, maximum_sample_stride)
|
276 |
+
|
277 |
+
cropped_length = self.sample_n_frames * current_sample_stride
|
278 |
+
start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
|
279 |
+
end_frame_ind = min(start_frame_ind + cropped_length, total_frames)
|
280 |
+
|
281 |
+
assert end_frame_ind - start_frame_ind >= self.sample_n_frames
|
282 |
+
frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)
|
283 |
+
|
284 |
+
condition_image_ind = random.sample(list(set(range(total_frames)) - set(frame_indices.tolist())), 1)
|
285 |
+
condition_image = torch.from_numpy(video_reader.get_batch(condition_image_ind).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
286 |
+
condition_image = condition_image / 255.
|
287 |
+
|
288 |
+
if self.shuffle_frames:
|
289 |
+
perm = np.random.permutation(self.sample_n_frames)
|
290 |
+
frame_indices = frame_indices[perm]
|
291 |
+
|
292 |
+
pixel_values = torch.from_numpy(video_reader.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
|
293 |
+
pixel_values = pixel_values / 255.
|
294 |
+
|
295 |
+
cam_params = [cam_params[indice] for indice in frame_indices]
|
296 |
+
if self.rescale_fxy:
|
297 |
+
ori_h, ori_w = pixel_values.shape[-2:]
|
298 |
+
ori_wh_ratio = ori_w / ori_h
|
299 |
+
if ori_wh_ratio > self.sample_wh_ratio: # rescale fx
|
300 |
+
resized_ori_w = self.sample_size[0] * ori_wh_ratio
|
301 |
+
for cam_param in cam_params:
|
302 |
+
cam_param.fx = resized_ori_w * cam_param.fx / self.sample_size[1]
|
303 |
+
else: # rescale fy
|
304 |
+
resized_ori_h = self.sample_size[1] / ori_wh_ratio
|
305 |
+
for cam_param in cam_params:
|
306 |
+
cam_param.fy = resized_ori_h * cam_param.fy / self.sample_size[0]
|
307 |
+
intrinsics = np.asarray([[cam_param.fx * self.sample_size[1],
|
308 |
+
cam_param.fy * self.sample_size[0],
|
309 |
+
cam_param.cx * self.sample_size[1],
|
310 |
+
cam_param.cy * self.sample_size[0]]
|
311 |
+
for cam_param in cam_params], dtype=np.float32)
|
312 |
+
intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
|
313 |
+
if self.relative_pose:
|
314 |
+
c2w_poses = self.get_relative_pose(cam_params)
|
315 |
+
else:
|
316 |
+
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
|
317 |
+
c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
|
318 |
+
if self.use_flip:
|
319 |
+
flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
|
320 |
+
else:
|
321 |
+
flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)
|
322 |
+
plucker_embedding = ray_condition(intrinsics, c2w, self.sample_size[0], self.sample_size[1], device='cpu',
|
323 |
+
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
|
324 |
+
|
325 |
+
return pixel_values, condition_image, plucker_embedding, video_caption, flip_flag, clip_name
|
326 |
+
|
327 |
+
def __len__(self):
|
328 |
+
return self.length
|
329 |
+
|
330 |
+
def __getitem__(self, idx):
|
331 |
+
while True:
|
332 |
+
try:
|
333 |
+
video, condition_image, plucker_embedding, video_caption, flip_flag, clip_name = self.get_batch(idx)
|
334 |
+
break
|
335 |
+
|
336 |
+
except Exception as e:
|
337 |
+
idx = random.randint(0, self.length - 1)
|
338 |
+
|
339 |
+
if self.use_flip:
|
340 |
+
video = self.pixel_transforms[0](video)
|
341 |
+
video = self.pixel_transforms[1](video, flip_flag)
|
342 |
+
for transform in self.pixel_transforms[2:]:
|
343 |
+
video = transform(video)
|
344 |
+
else:
|
345 |
+
for transform in self.pixel_transforms:
|
346 |
+
video = transform(video)
|
347 |
+
for transform in self.pixel_transforms:
|
348 |
+
condition_image = transform(condition_image)
|
349 |
+
if self.return_clip_name:
|
350 |
+
sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding, video_caption=video_caption, clip_name=clip_name)
|
351 |
+
else:
|
352 |
+
sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding, video_caption=video_caption)
|
353 |
+
|
354 |
+
return sample
|
355 |
+
|
cameractrl/models/attention.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Optional
|
3 |
+
from diffusers.models.attention import TemporalBasicTransformerBlock, _chunked_feed_forward
|
4 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
5 |
+
|
6 |
+
|
7 |
+
@maybe_allow_in_graph
|
8 |
+
class TemporalPoseCondTransformerBlock(TemporalBasicTransformerBlock):
|
9 |
+
def forward(
|
10 |
+
self,
|
11 |
+
hidden_states: torch.FloatTensor, # [bs * num_frame, h * w, c]
|
12 |
+
num_frames: int,
|
13 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None, # [bs * h * w, 1, c]
|
14 |
+
pose_feature: Optional[torch.FloatTensor] = None, # [bs, c, n_frame, h, w]
|
15 |
+
) -> torch.FloatTensor:
|
16 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
17 |
+
# 0. Self-Attention
|
18 |
+
|
19 |
+
batch_frames, seq_length, channels = hidden_states.shape
|
20 |
+
batch_size = batch_frames // num_frames
|
21 |
+
|
22 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
|
23 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
24 |
+
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) # [bs * h * w, frame, c]
|
25 |
+
|
26 |
+
residual = hidden_states
|
27 |
+
hidden_states = self.norm_in(hidden_states)
|
28 |
+
|
29 |
+
if self._chunk_size is not None:
|
30 |
+
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
31 |
+
else:
|
32 |
+
hidden_states = self.ff_in(hidden_states)
|
33 |
+
|
34 |
+
if self.is_res:
|
35 |
+
hidden_states = hidden_states + residual
|
36 |
+
|
37 |
+
norm_hidden_states = self.norm1(hidden_states)
|
38 |
+
pose_feature = pose_feature.permute(0, 3, 4, 2, 1).reshape(batch_size * seq_length, num_frames, -1)
|
39 |
+
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None, pose_feature=pose_feature)
|
40 |
+
hidden_states = attn_output + hidden_states
|
41 |
+
|
42 |
+
# 3. Cross-Attention
|
43 |
+
if self.attn2 is not None:
|
44 |
+
norm_hidden_states = self.norm2(hidden_states)
|
45 |
+
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, pose_feature=pose_feature)
|
46 |
+
hidden_states = attn_output + hidden_states
|
47 |
+
|
48 |
+
# 4. Feed-forward
|
49 |
+
norm_hidden_states = self.norm3(hidden_states)
|
50 |
+
|
51 |
+
if self._chunk_size is not None:
|
52 |
+
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
53 |
+
else:
|
54 |
+
ff_output = self.ff(norm_hidden_states)
|
55 |
+
|
56 |
+
if self.is_res:
|
57 |
+
hidden_states = ff_output + hidden_states
|
58 |
+
else:
|
59 |
+
hidden_states = ff_output
|
60 |
+
|
61 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
|
62 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
63 |
+
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
|
64 |
+
|
65 |
+
return hidden_states
|
cameractrl/models/attention_processor.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch.nn.init as init
|
5 |
+
import logging
|
6 |
+
from diffusers.models.attention import Attention
|
7 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_xformers_available
|
8 |
+
from typing import Optional, Callable
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
|
12 |
+
if is_xformers_available():
|
13 |
+
import xformers
|
14 |
+
import xformers.ops
|
15 |
+
else:
|
16 |
+
xformers = None
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class AttnProcessor:
|
22 |
+
r"""
|
23 |
+
Default processor for performing attention-related computations.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __call__(
|
27 |
+
self,
|
28 |
+
attn: Attention,
|
29 |
+
hidden_states: torch.FloatTensor,
|
30 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
31 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
32 |
+
temb: Optional[torch.FloatTensor] = None,
|
33 |
+
scale: float = 1.0,
|
34 |
+
pose_feature=None
|
35 |
+
) -> torch.Tensor:
|
36 |
+
residual = hidden_states
|
37 |
+
|
38 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
39 |
+
|
40 |
+
if attn.spatial_norm is not None:
|
41 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
42 |
+
|
43 |
+
input_ndim = hidden_states.ndim
|
44 |
+
|
45 |
+
if input_ndim == 4:
|
46 |
+
batch_size, channel, height, width = hidden_states.shape
|
47 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
48 |
+
|
49 |
+
batch_size, sequence_length, _ = (
|
50 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
51 |
+
)
|
52 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
53 |
+
|
54 |
+
if attn.group_norm is not None:
|
55 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
56 |
+
|
57 |
+
query = attn.to_q(hidden_states, *args)
|
58 |
+
|
59 |
+
if encoder_hidden_states is None:
|
60 |
+
encoder_hidden_states = hidden_states
|
61 |
+
elif attn.norm_cross:
|
62 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
63 |
+
|
64 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
65 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
66 |
+
|
67 |
+
query = attn.head_to_batch_dim(query)
|
68 |
+
key = attn.head_to_batch_dim(key)
|
69 |
+
value = attn.head_to_batch_dim(value)
|
70 |
+
|
71 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
72 |
+
hidden_states = torch.bmm(attention_probs, value)
|
73 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
74 |
+
|
75 |
+
# linear proj
|
76 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
77 |
+
# dropout
|
78 |
+
hidden_states = attn.to_out[1](hidden_states)
|
79 |
+
|
80 |
+
if input_ndim == 4:
|
81 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
82 |
+
|
83 |
+
if attn.residual_connection:
|
84 |
+
hidden_states = hidden_states + residual
|
85 |
+
|
86 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
87 |
+
|
88 |
+
return hidden_states
|
89 |
+
|
90 |
+
|
91 |
+
class AttnProcessor2_0:
|
92 |
+
r"""
|
93 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
94 |
+
"""
|
95 |
+
|
96 |
+
def __init__(self):
|
97 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
98 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
99 |
+
|
100 |
+
def __call__(
|
101 |
+
self,
|
102 |
+
attn: Attention,
|
103 |
+
hidden_states: torch.FloatTensor,
|
104 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
105 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
106 |
+
temb: Optional[torch.FloatTensor] = None,
|
107 |
+
scale: float = 1.0,
|
108 |
+
pose_feature=None
|
109 |
+
) -> torch.FloatTensor:
|
110 |
+
residual = hidden_states
|
111 |
+
|
112 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
113 |
+
|
114 |
+
if attn.spatial_norm is not None:
|
115 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
116 |
+
|
117 |
+
input_ndim = hidden_states.ndim
|
118 |
+
|
119 |
+
if input_ndim == 4:
|
120 |
+
batch_size, channel, height, width = hidden_states.shape
|
121 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
122 |
+
|
123 |
+
batch_size, sequence_length, _ = (
|
124 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
125 |
+
)
|
126 |
+
|
127 |
+
if attention_mask is not None:
|
128 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
129 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
130 |
+
# (batch, heads, source_length, target_length)
|
131 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
132 |
+
|
133 |
+
if attn.group_norm is not None:
|
134 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
135 |
+
|
136 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
137 |
+
query = attn.to_q(hidden_states, *args)
|
138 |
+
|
139 |
+
if encoder_hidden_states is None:
|
140 |
+
encoder_hidden_states = hidden_states
|
141 |
+
elif attn.norm_cross:
|
142 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
143 |
+
|
144 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
145 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
146 |
+
|
147 |
+
inner_dim = key.shape[-1]
|
148 |
+
head_dim = inner_dim // attn.heads
|
149 |
+
|
150 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
151 |
+
|
152 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
153 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
154 |
+
|
155 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
156 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
157 |
+
hidden_states = F.scaled_dot_product_attention(
|
158 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
159 |
+
)
|
160 |
+
|
161 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
162 |
+
hidden_states = hidden_states.to(query.dtype)
|
163 |
+
|
164 |
+
# linear proj
|
165 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
166 |
+
# dropout
|
167 |
+
hidden_states = attn.to_out[1](hidden_states)
|
168 |
+
|
169 |
+
if input_ndim == 4:
|
170 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
171 |
+
|
172 |
+
if attn.residual_connection:
|
173 |
+
hidden_states = hidden_states + residual
|
174 |
+
|
175 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
176 |
+
|
177 |
+
return hidden_states
|
178 |
+
|
179 |
+
|
180 |
+
class XFormersAttnProcessor:
|
181 |
+
r"""
|
182 |
+
Processor for implementing memory efficient attention using xFormers.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
186 |
+
The base
|
187 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
188 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
189 |
+
operator.
|
190 |
+
"""
|
191 |
+
|
192 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
193 |
+
self.attention_op = attention_op
|
194 |
+
|
195 |
+
def __call__(
|
196 |
+
self,
|
197 |
+
attn: Attention,
|
198 |
+
hidden_states: torch.FloatTensor,
|
199 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
200 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
201 |
+
temb: Optional[torch.FloatTensor] = None,
|
202 |
+
scale: float = 1.0,
|
203 |
+
pose_feature=None
|
204 |
+
) -> torch.FloatTensor:
|
205 |
+
residual = hidden_states
|
206 |
+
|
207 |
+
args = () if USE_PEFT_BACKEND else (scale,)
|
208 |
+
|
209 |
+
if attn.spatial_norm is not None:
|
210 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
211 |
+
|
212 |
+
input_ndim = hidden_states.ndim
|
213 |
+
|
214 |
+
if input_ndim == 4:
|
215 |
+
batch_size, channel, height, width = hidden_states.shape
|
216 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
217 |
+
|
218 |
+
batch_size, key_tokens, _ = (
|
219 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
220 |
+
)
|
221 |
+
|
222 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
223 |
+
if attention_mask is not None:
|
224 |
+
# expand our mask's singleton query_tokens dimension:
|
225 |
+
# [batch*heads, 1, key_tokens] ->
|
226 |
+
# [batch*heads, query_tokens, key_tokens]
|
227 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
228 |
+
# [batch*heads, query_tokens, key_tokens]
|
229 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
230 |
+
_, query_tokens, _ = hidden_states.shape
|
231 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
232 |
+
|
233 |
+
if attn.group_norm is not None:
|
234 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
235 |
+
|
236 |
+
query = attn.to_q(hidden_states, *args)
|
237 |
+
|
238 |
+
if encoder_hidden_states is None:
|
239 |
+
encoder_hidden_states = hidden_states
|
240 |
+
elif attn.norm_cross:
|
241 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
242 |
+
|
243 |
+
key = attn.to_k(encoder_hidden_states, *args)
|
244 |
+
value = attn.to_v(encoder_hidden_states, *args)
|
245 |
+
|
246 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
247 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
248 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
249 |
+
|
250 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
251 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
252 |
+
)
|
253 |
+
hidden_states = hidden_states.to(query.dtype)
|
254 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
255 |
+
|
256 |
+
# linear proj
|
257 |
+
hidden_states = attn.to_out[0](hidden_states, *args)
|
258 |
+
# dropout
|
259 |
+
hidden_states = attn.to_out[1](hidden_states)
|
260 |
+
|
261 |
+
if input_ndim == 4:
|
262 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
263 |
+
|
264 |
+
if attn.residual_connection:
|
265 |
+
hidden_states = hidden_states + residual
|
266 |
+
|
267 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
268 |
+
|
269 |
+
return hidden_states
|
270 |
+
|
271 |
+
|
272 |
+
class PoseAdaptorAttnProcessor(nn.Module):
|
273 |
+
def __init__(self,
|
274 |
+
hidden_size, # dimension of hidden state
|
275 |
+
pose_feature_dim=None, # dimension of the pose feature
|
276 |
+
cross_attention_dim=None, # dimension of the text embedding
|
277 |
+
query_condition=False,
|
278 |
+
key_value_condition=False,
|
279 |
+
scale=1.0):
|
280 |
+
super().__init__()
|
281 |
+
|
282 |
+
self.hidden_size = hidden_size
|
283 |
+
self.pose_feature_dim = pose_feature_dim
|
284 |
+
self.cross_attention_dim = cross_attention_dim
|
285 |
+
self.scale = scale
|
286 |
+
self.query_condition = query_condition
|
287 |
+
self.key_value_condition = key_value_condition
|
288 |
+
assert hidden_size == pose_feature_dim
|
289 |
+
if self.query_condition and self.key_value_condition:
|
290 |
+
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
|
291 |
+
init.zeros_(self.qkv_merge.weight)
|
292 |
+
init.zeros_(self.qkv_merge.bias)
|
293 |
+
elif self.query_condition:
|
294 |
+
self.q_merge = nn.Linear(hidden_size, hidden_size)
|
295 |
+
init.zeros_(self.q_merge.weight)
|
296 |
+
init.zeros_(self.q_merge.bias)
|
297 |
+
else:
|
298 |
+
self.kv_merge = nn.Linear(hidden_size, hidden_size)
|
299 |
+
init.zeros_(self.kv_merge.weight)
|
300 |
+
init.zeros_(self.kv_merge.bias)
|
301 |
+
|
302 |
+
def forward(self,
|
303 |
+
attn,
|
304 |
+
hidden_states,
|
305 |
+
pose_feature,
|
306 |
+
encoder_hidden_states=None,
|
307 |
+
attention_mask=None,
|
308 |
+
temb=None,
|
309 |
+
scale=None,):
|
310 |
+
assert pose_feature is not None
|
311 |
+
pose_embedding_scale = (scale or self.scale)
|
312 |
+
|
313 |
+
residual = hidden_states
|
314 |
+
if attn.spatial_norm is not None:
|
315 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
316 |
+
|
317 |
+
assert hidden_states.ndim == 3 and pose_feature.ndim == 3
|
318 |
+
|
319 |
+
if self.query_condition and self.key_value_condition:
|
320 |
+
assert encoder_hidden_states is None
|
321 |
+
|
322 |
+
if encoder_hidden_states is None:
|
323 |
+
encoder_hidden_states = hidden_states
|
324 |
+
|
325 |
+
assert encoder_hidden_states.ndim == 3
|
326 |
+
|
327 |
+
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
|
328 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
|
329 |
+
|
330 |
+
if attn.group_norm is not None:
|
331 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
332 |
+
|
333 |
+
if attn.norm_cross:
|
334 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
335 |
+
|
336 |
+
if self.query_condition and self.key_value_condition: # only self attention
|
337 |
+
query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
|
338 |
+
key_value_hidden_state = query_hidden_state
|
339 |
+
elif self.query_condition:
|
340 |
+
query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
|
341 |
+
key_value_hidden_state = encoder_hidden_states
|
342 |
+
else:
|
343 |
+
key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
|
344 |
+
query_hidden_state = hidden_states
|
345 |
+
|
346 |
+
# original attention
|
347 |
+
query = attn.to_q(query_hidden_state)
|
348 |
+
key = attn.to_k(key_value_hidden_state)
|
349 |
+
value = attn.to_v(key_value_hidden_state)
|
350 |
+
|
351 |
+
query = attn.head_to_batch_dim(query)
|
352 |
+
key = attn.head_to_batch_dim(key)
|
353 |
+
value = attn.head_to_batch_dim(value)
|
354 |
+
|
355 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
356 |
+
hidden_states = torch.bmm(attention_probs, value)
|
357 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
358 |
+
|
359 |
+
# linear proj
|
360 |
+
hidden_states = attn.to_out[0](hidden_states)
|
361 |
+
# dropout
|
362 |
+
hidden_states = attn.to_out[1](hidden_states)
|
363 |
+
|
364 |
+
if attn.residual_connection:
|
365 |
+
hidden_states = hidden_states + residual
|
366 |
+
|
367 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
368 |
+
|
369 |
+
return hidden_states
|
370 |
+
|
371 |
+
|
372 |
+
class PoseAdaptorAttnProcessor2_0(nn.Module):
|
373 |
+
def __init__(self,
|
374 |
+
hidden_size, # dimension of hidden state
|
375 |
+
pose_feature_dim=None, # dimension of the pose feature
|
376 |
+
cross_attention_dim=None, # dimension of the text embedding
|
377 |
+
query_condition=False,
|
378 |
+
key_value_condition=False,
|
379 |
+
scale=1.0):
|
380 |
+
super().__init__()
|
381 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
382 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
383 |
+
|
384 |
+
self.hidden_size = hidden_size
|
385 |
+
self.pose_feature_dim = pose_feature_dim
|
386 |
+
self.cross_attention_dim = cross_attention_dim
|
387 |
+
self.scale = scale
|
388 |
+
self.query_condition = query_condition
|
389 |
+
self.key_value_condition = key_value_condition
|
390 |
+
assert hidden_size == pose_feature_dim
|
391 |
+
if self.query_condition and self.key_value_condition:
|
392 |
+
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
|
393 |
+
init.zeros_(self.qkv_merge.weight)
|
394 |
+
init.zeros_(self.qkv_merge.bias)
|
395 |
+
elif self.query_condition:
|
396 |
+
self.q_merge = nn.Linear(hidden_size, hidden_size)
|
397 |
+
init.zeros_(self.q_merge.weight)
|
398 |
+
init.zeros_(self.q_merge.bias)
|
399 |
+
else:
|
400 |
+
self.kv_merge = nn.Linear(hidden_size, hidden_size)
|
401 |
+
init.zeros_(self.kv_merge.weight)
|
402 |
+
init.zeros_(self.kv_merge.bias)
|
403 |
+
|
404 |
+
def forward(self,
|
405 |
+
attn,
|
406 |
+
hidden_states,
|
407 |
+
pose_feature,
|
408 |
+
encoder_hidden_states=None,
|
409 |
+
attention_mask=None,
|
410 |
+
temb=None,
|
411 |
+
scale=None,):
|
412 |
+
assert pose_feature is not None
|
413 |
+
pose_embedding_scale = (scale or self.scale)
|
414 |
+
|
415 |
+
residual = hidden_states
|
416 |
+
if attn.spatial_norm is not None:
|
417 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
418 |
+
|
419 |
+
assert hidden_states.ndim == 3 and pose_feature.ndim == 3
|
420 |
+
|
421 |
+
if self.query_condition and self.key_value_condition:
|
422 |
+
assert encoder_hidden_states is None
|
423 |
+
|
424 |
+
if encoder_hidden_states is None:
|
425 |
+
encoder_hidden_states = hidden_states
|
426 |
+
|
427 |
+
assert encoder_hidden_states.ndim == 3
|
428 |
+
|
429 |
+
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
|
430 |
+
if attention_mask is not None:
|
431 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
|
432 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
433 |
+
# (batch, heads, source_length, target_length)
|
434 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
435 |
+
|
436 |
+
if attn.group_norm is not None:
|
437 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
438 |
+
|
439 |
+
if attn.norm_cross:
|
440 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
441 |
+
|
442 |
+
if self.query_condition and self.key_value_condition: # only self attention
|
443 |
+
query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
|
444 |
+
key_value_hidden_state = query_hidden_state
|
445 |
+
elif self.query_condition:
|
446 |
+
query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
|
447 |
+
key_value_hidden_state = encoder_hidden_states
|
448 |
+
else:
|
449 |
+
key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
|
450 |
+
query_hidden_state = hidden_states
|
451 |
+
|
452 |
+
# original attention
|
453 |
+
query = attn.to_q(query_hidden_state)
|
454 |
+
key = attn.to_k(key_value_hidden_state)
|
455 |
+
value = attn.to_v(key_value_hidden_state)
|
456 |
+
|
457 |
+
inner_dim = key.shape[-1]
|
458 |
+
head_dim = inner_dim // attn.heads
|
459 |
+
|
460 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # [bs, seq_len, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
|
461 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
462 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
463 |
+
|
464 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) # [bs, nhead, seq_len, head_dim]
|
465 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # [bs, seq_len, dim]
|
466 |
+
hidden_states = hidden_states.to(query.dtype)
|
467 |
+
|
468 |
+
# linear proj
|
469 |
+
hidden_states = attn.to_out[0](hidden_states)
|
470 |
+
# dropout
|
471 |
+
hidden_states = attn.to_out[1](hidden_states)
|
472 |
+
|
473 |
+
if attn.residual_connection:
|
474 |
+
hidden_states = hidden_states + residual
|
475 |
+
|
476 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
477 |
+
|
478 |
+
return hidden_states
|
479 |
+
|
480 |
+
|
481 |
+
class PoseAdaptorXFormersAttnProcessor(nn.Module):
|
482 |
+
def __init__(self,
|
483 |
+
hidden_size, # dimension of hidden state
|
484 |
+
pose_feature_dim=None, # dimension of the pose feature
|
485 |
+
cross_attention_dim=None, # dimension of the text embedding
|
486 |
+
query_condition=False,
|
487 |
+
key_value_condition=False,
|
488 |
+
scale=1.0,
|
489 |
+
attention_op: Optional[Callable] = None):
|
490 |
+
super().__init__()
|
491 |
+
|
492 |
+
self.hidden_size = hidden_size
|
493 |
+
self.pose_feature_dim = pose_feature_dim
|
494 |
+
self.cross_attention_dim = cross_attention_dim
|
495 |
+
self.scale = scale
|
496 |
+
self.query_condition = query_condition
|
497 |
+
self.key_value_condition = key_value_condition
|
498 |
+
self.attention_op = attention_op
|
499 |
+
assert hidden_size == pose_feature_dim
|
500 |
+
if self.query_condition and self.key_value_condition:
|
501 |
+
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
|
502 |
+
init.zeros_(self.qkv_merge.weight)
|
503 |
+
init.zeros_(self.qkv_merge.bias)
|
504 |
+
elif self.query_condition:
|
505 |
+
self.q_merge = nn.Linear(hidden_size, hidden_size)
|
506 |
+
init.zeros_(self.q_merge.weight)
|
507 |
+
init.zeros_(self.q_merge.bias)
|
508 |
+
else:
|
509 |
+
self.kv_merge = nn.Linear(hidden_size, hidden_size)
|
510 |
+
init.zeros_(self.kv_merge.weight)
|
511 |
+
init.zeros_(self.kv_merge.bias)
|
512 |
+
|
513 |
+
def forward(self,
|
514 |
+
attn,
|
515 |
+
hidden_states,
|
516 |
+
pose_feature,
|
517 |
+
encoder_hidden_states=None,
|
518 |
+
attention_mask=None,
|
519 |
+
temb=None,
|
520 |
+
scale=None,):
|
521 |
+
assert pose_feature is not None
|
522 |
+
pose_embedding_scale = (scale or self.scale)
|
523 |
+
|
524 |
+
residual = hidden_states
|
525 |
+
if attn.spatial_norm is not None:
|
526 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
527 |
+
|
528 |
+
assert hidden_states.ndim == 3 and pose_feature.ndim == 3
|
529 |
+
|
530 |
+
if self.query_condition and self.key_value_condition:
|
531 |
+
assert encoder_hidden_states is None
|
532 |
+
|
533 |
+
if encoder_hidden_states is None:
|
534 |
+
encoder_hidden_states = hidden_states
|
535 |
+
|
536 |
+
assert encoder_hidden_states.ndim == 3
|
537 |
+
|
538 |
+
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
|
539 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
|
540 |
+
if attention_mask is not None:
|
541 |
+
# expand our mask's singleton query_tokens dimension:
|
542 |
+
# [batch*heads, 1, key_tokens] ->
|
543 |
+
# [batch*heads, query_tokens, key_tokens]
|
544 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
545 |
+
# [batch*heads, query_tokens, key_tokens]
|
546 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
547 |
+
_, query_tokens, _ = hidden_states.shape
|
548 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
549 |
+
|
550 |
+
if attn.group_norm is not None:
|
551 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
552 |
+
|
553 |
+
if attn.norm_cross:
|
554 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
555 |
+
|
556 |
+
if self.query_condition and self.key_value_condition: # only self attention
|
557 |
+
query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
|
558 |
+
key_value_hidden_state = query_hidden_state
|
559 |
+
elif self.query_condition:
|
560 |
+
query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
|
561 |
+
key_value_hidden_state = encoder_hidden_states
|
562 |
+
else:
|
563 |
+
key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
|
564 |
+
query_hidden_state = hidden_states
|
565 |
+
|
566 |
+
# original attention
|
567 |
+
query = attn.to_q(query_hidden_state)
|
568 |
+
key = attn.to_k(key_value_hidden_state)
|
569 |
+
value = attn.to_v(key_value_hidden_state)
|
570 |
+
|
571 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
572 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
573 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
574 |
+
|
575 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
576 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
577 |
+
)
|
578 |
+
hidden_states = hidden_states.to(query.dtype)
|
579 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
580 |
+
|
581 |
+
# linear proj
|
582 |
+
hidden_states = attn.to_out[0](hidden_states)
|
583 |
+
# dropout
|
584 |
+
hidden_states = attn.to_out[1](hidden_states)
|
585 |
+
|
586 |
+
if attn.residual_connection:
|
587 |
+
hidden_states = hidden_states + residual
|
588 |
+
|
589 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
590 |
+
|
591 |
+
return hidden_states
|
cameractrl/models/motion_module.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Callable, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
from diffusers.utils import BaseOutput
|
8 |
+
from diffusers.models.attention_processor import Attention
|
9 |
+
from diffusers.models.attention import FeedForward
|
10 |
+
|
11 |
+
from typing import Dict, Any
|
12 |
+
from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor
|
13 |
+
|
14 |
+
from einops import rearrange
|
15 |
+
import math
|
16 |
+
|
17 |
+
|
18 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
19 |
+
def forward(self, x):
|
20 |
+
# return super().forward(x)
|
21 |
+
|
22 |
+
video_length = x.shape[2]
|
23 |
+
|
24 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
25 |
+
x = super().forward(x)
|
26 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
27 |
+
|
28 |
+
return x
|
29 |
+
|
30 |
+
def zero_module(module):
|
31 |
+
# Zero out the parameters of a module and return it.
|
32 |
+
for p in module.parameters():
|
33 |
+
p.detach().zero_()
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
39 |
+
sample: torch.FloatTensor
|
40 |
+
|
41 |
+
|
42 |
+
def get_motion_module(
|
43 |
+
in_channels,
|
44 |
+
motion_module_type: str,
|
45 |
+
motion_module_kwargs: dict
|
46 |
+
):
|
47 |
+
if motion_module_type == "Vanilla":
|
48 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)
|
49 |
+
else:
|
50 |
+
raise ValueError
|
51 |
+
|
52 |
+
|
53 |
+
class VanillaTemporalModule(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
in_channels,
|
57 |
+
num_attention_heads=8,
|
58 |
+
num_transformer_block=2,
|
59 |
+
attention_block_types=("Temporal_Self",),
|
60 |
+
temporal_position_encoding=True,
|
61 |
+
temporal_position_encoding_max_len=32,
|
62 |
+
temporal_attention_dim_div=1,
|
63 |
+
cross_attention_dim=320,
|
64 |
+
zero_initialize=True,
|
65 |
+
encoder_hidden_states_query=(False, False),
|
66 |
+
attention_activation_scale=1.0,
|
67 |
+
attention_processor_kwargs: Dict = {},
|
68 |
+
causal_temporal_attention=False,
|
69 |
+
causal_temporal_attention_mask_type="",
|
70 |
+
rescale_output_factor=1.0
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
75 |
+
in_channels=in_channels,
|
76 |
+
num_attention_heads=num_attention_heads,
|
77 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
78 |
+
num_layers=num_transformer_block,
|
79 |
+
attention_block_types=attention_block_types,
|
80 |
+
cross_attention_dim=cross_attention_dim,
|
81 |
+
temporal_position_encoding=temporal_position_encoding,
|
82 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
83 |
+
encoder_hidden_states_query=encoder_hidden_states_query,
|
84 |
+
attention_activation_scale=attention_activation_scale,
|
85 |
+
attention_processor_kwargs=attention_processor_kwargs,
|
86 |
+
causal_temporal_attention=causal_temporal_attention,
|
87 |
+
causal_temporal_attention_mask_type=causal_temporal_attention_mask_type,
|
88 |
+
rescale_output_factor=rescale_output_factor
|
89 |
+
)
|
90 |
+
|
91 |
+
if zero_initialize:
|
92 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
93 |
+
|
94 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
|
95 |
+
cross_attention_kwargs: Dict[str, Any] = {}):
|
96 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cross_attention_kwargs=cross_attention_kwargs)
|
97 |
+
|
98 |
+
output = hidden_states
|
99 |
+
return output
|
100 |
+
|
101 |
+
|
102 |
+
class TemporalTransformer3DModel(nn.Module):
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
in_channels,
|
106 |
+
num_attention_heads,
|
107 |
+
attention_head_dim,
|
108 |
+
num_layers,
|
109 |
+
attention_block_types=("Temporal_Self", "Temporal_Self",),
|
110 |
+
dropout=0.0,
|
111 |
+
norm_num_groups=32,
|
112 |
+
cross_attention_dim=320,
|
113 |
+
activation_fn="geglu",
|
114 |
+
attention_bias=False,
|
115 |
+
upcast_attention=False,
|
116 |
+
temporal_position_encoding=False,
|
117 |
+
temporal_position_encoding_max_len=32,
|
118 |
+
encoder_hidden_states_query=(False, False),
|
119 |
+
attention_activation_scale=1.0,
|
120 |
+
attention_processor_kwargs: Dict = {},
|
121 |
+
|
122 |
+
causal_temporal_attention=None,
|
123 |
+
causal_temporal_attention_mask_type="",
|
124 |
+
rescale_output_factor=1.0
|
125 |
+
):
|
126 |
+
super().__init__()
|
127 |
+
assert causal_temporal_attention is not None
|
128 |
+
self.causal_temporal_attention = causal_temporal_attention
|
129 |
+
|
130 |
+
assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "")
|
131 |
+
self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type
|
132 |
+
self.causal_temporal_attention_mask = None
|
133 |
+
|
134 |
+
inner_dim = num_attention_heads * attention_head_dim
|
135 |
+
|
136 |
+
self.norm = InflatedGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
137 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
138 |
+
|
139 |
+
self.transformer_blocks = nn.ModuleList(
|
140 |
+
[
|
141 |
+
TemporalTransformerBlock(
|
142 |
+
dim=inner_dim,
|
143 |
+
num_attention_heads=num_attention_heads,
|
144 |
+
attention_head_dim=attention_head_dim,
|
145 |
+
attention_block_types=attention_block_types,
|
146 |
+
dropout=dropout,
|
147 |
+
norm_num_groups=norm_num_groups,
|
148 |
+
cross_attention_dim=cross_attention_dim,
|
149 |
+
activation_fn=activation_fn,
|
150 |
+
attention_bias=attention_bias,
|
151 |
+
upcast_attention=upcast_attention,
|
152 |
+
temporal_position_encoding=temporal_position_encoding,
|
153 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
154 |
+
encoder_hidden_states_query=encoder_hidden_states_query,
|
155 |
+
attention_activation_scale=attention_activation_scale,
|
156 |
+
attention_processor_kwargs=attention_processor_kwargs,
|
157 |
+
rescale_output_factor=rescale_output_factor,
|
158 |
+
)
|
159 |
+
for d in range(num_layers)
|
160 |
+
]
|
161 |
+
)
|
162 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
163 |
+
|
164 |
+
def get_causal_temporal_attention_mask(self, hidden_states):
|
165 |
+
batch_size, sequence_length, dim = hidden_states.shape
|
166 |
+
|
167 |
+
if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (
|
168 |
+
batch_size, sequence_length, sequence_length):
|
169 |
+
if self.causal_temporal_attention_mask_type == "causal":
|
170 |
+
# 1. vanilla causal mask
|
171 |
+
mask = torch.tril(torch.ones(sequence_length, sequence_length))
|
172 |
+
|
173 |
+
elif self.causal_temporal_attention_mask_type == "2-seq":
|
174 |
+
# 2. 2-seq
|
175 |
+
mask = torch.zeros(sequence_length, sequence_length)
|
176 |
+
mask[:sequence_length // 2, :sequence_length // 2] = 1
|
177 |
+
mask[-sequence_length // 2:, -sequence_length // 2:] = 1
|
178 |
+
|
179 |
+
elif self.causal_temporal_attention_mask_type == "0-prev":
|
180 |
+
# attn to the previous frame
|
181 |
+
indices = torch.arange(sequence_length)
|
182 |
+
indices_prev = indices - 1
|
183 |
+
indices_prev[0] = 0
|
184 |
+
mask = torch.zeros(sequence_length, sequence_length)
|
185 |
+
mask[:, 0] = 1.
|
186 |
+
mask[indices, indices_prev] = 1.
|
187 |
+
|
188 |
+
elif self.causal_temporal_attention_mask_type == "0":
|
189 |
+
# only attn to first frame
|
190 |
+
mask = torch.zeros(sequence_length, sequence_length)
|
191 |
+
mask[:, 0] = 1
|
192 |
+
|
193 |
+
elif self.causal_temporal_attention_mask_type == "wo-self":
|
194 |
+
indices = torch.arange(sequence_length)
|
195 |
+
mask = torch.ones(sequence_length, sequence_length)
|
196 |
+
mask[indices, indices] = 0
|
197 |
+
|
198 |
+
elif self.causal_temporal_attention_mask_type == "circle":
|
199 |
+
indices = torch.arange(sequence_length)
|
200 |
+
indices_prev = indices - 1
|
201 |
+
indices_prev[0] = 0
|
202 |
+
|
203 |
+
mask = torch.eye(sequence_length)
|
204 |
+
mask[indices, indices_prev] = 1
|
205 |
+
mask[0, -1] = 1
|
206 |
+
|
207 |
+
else:
|
208 |
+
raise ValueError
|
209 |
+
|
210 |
+
# generate attention mask fron binary values
|
211 |
+
mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
212 |
+
mask = mask.unsqueeze(0)
|
213 |
+
mask = mask.repeat(batch_size, 1, 1)
|
214 |
+
|
215 |
+
self.causal_temporal_attention_mask = mask.to(hidden_states.device)
|
216 |
+
|
217 |
+
return self.causal_temporal_attention_mask
|
218 |
+
|
219 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
|
220 |
+
cross_attention_kwargs: Dict[str, Any] = {},):
|
221 |
+
residual = hidden_states
|
222 |
+
|
223 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
224 |
+
height, width = hidden_states.shape[-2:]
|
225 |
+
|
226 |
+
hidden_states = self.norm(hidden_states)
|
227 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c")
|
228 |
+
hidden_states = self.proj_in(hidden_states)
|
229 |
+
|
230 |
+
attention_mask = self.get_causal_temporal_attention_mask(
|
231 |
+
hidden_states) if self.causal_temporal_attention else attention_mask
|
232 |
+
|
233 |
+
# Transformer Blocks
|
234 |
+
for block in self.transformer_blocks:
|
235 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states,
|
236 |
+
attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs)
|
237 |
+
hidden_states = self.proj_out(hidden_states)
|
238 |
+
|
239 |
+
hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width)
|
240 |
+
|
241 |
+
output = hidden_states + residual
|
242 |
+
|
243 |
+
return output
|
244 |
+
|
245 |
+
|
246 |
+
class TemporalTransformerBlock(nn.Module):
|
247 |
+
def __init__(
|
248 |
+
self,
|
249 |
+
dim,
|
250 |
+
num_attention_heads,
|
251 |
+
attention_head_dim,
|
252 |
+
attention_block_types=("Temporal_Self", "Temporal_Self",),
|
253 |
+
dropout=0.0,
|
254 |
+
norm_num_groups=32,
|
255 |
+
cross_attention_dim=768,
|
256 |
+
activation_fn="geglu",
|
257 |
+
attention_bias=False,
|
258 |
+
upcast_attention=False,
|
259 |
+
temporal_position_encoding=False,
|
260 |
+
temporal_position_encoding_max_len=32,
|
261 |
+
encoder_hidden_states_query=(False, False),
|
262 |
+
attention_activation_scale=1.0,
|
263 |
+
attention_processor_kwargs: Dict = {},
|
264 |
+
rescale_output_factor=1.0
|
265 |
+
):
|
266 |
+
super().__init__()
|
267 |
+
|
268 |
+
attention_blocks = []
|
269 |
+
norms = []
|
270 |
+
self.attention_block_types = attention_block_types
|
271 |
+
|
272 |
+
for block_idx, block_name in enumerate(attention_block_types):
|
273 |
+
attention_blocks.append(
|
274 |
+
TemporalSelfAttention(
|
275 |
+
attention_mode=block_name,
|
276 |
+
cross_attention_dim=cross_attention_dim if block_name in ['Temporal_Cross', 'Temporal_Pose_Adaptor'] else None,
|
277 |
+
query_dim=dim,
|
278 |
+
heads=num_attention_heads,
|
279 |
+
dim_head=attention_head_dim,
|
280 |
+
dropout=dropout,
|
281 |
+
bias=attention_bias,
|
282 |
+
upcast_attention=upcast_attention,
|
283 |
+
temporal_position_encoding=temporal_position_encoding,
|
284 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
285 |
+
rescale_output_factor=rescale_output_factor,
|
286 |
+
)
|
287 |
+
)
|
288 |
+
norms.append(nn.LayerNorm(dim))
|
289 |
+
|
290 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
291 |
+
self.norms = nn.ModuleList(norms)
|
292 |
+
|
293 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
294 |
+
self.ff_norm = nn.LayerNorm(dim)
|
295 |
+
|
296 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs: Dict[str, Any] = {}):
|
297 |
+
for attention_block, norm, attention_block_type in zip(self.attention_blocks, self.norms, self.attention_block_types):
|
298 |
+
norm_hidden_states = norm(hidden_states)
|
299 |
+
hidden_states = attention_block(
|
300 |
+
norm_hidden_states,
|
301 |
+
encoder_hidden_states=norm_hidden_states if attention_block_type == 'Temporal_Self' else encoder_hidden_states,
|
302 |
+
attention_mask=attention_mask,
|
303 |
+
**cross_attention_kwargs
|
304 |
+
) + hidden_states
|
305 |
+
|
306 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
307 |
+
|
308 |
+
output = hidden_states
|
309 |
+
return output
|
310 |
+
|
311 |
+
|
312 |
+
class PositionalEncoding(nn.Module):
|
313 |
+
def __init__(
|
314 |
+
self,
|
315 |
+
d_model,
|
316 |
+
dropout=0.,
|
317 |
+
max_len=32,
|
318 |
+
):
|
319 |
+
super().__init__()
|
320 |
+
self.dropout = nn.Dropout(p=dropout)
|
321 |
+
position = torch.arange(max_len).unsqueeze(1)
|
322 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
323 |
+
pe = torch.zeros(1, max_len, d_model)
|
324 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
325 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
326 |
+
self.register_buffer('pe', pe)
|
327 |
+
|
328 |
+
def forward(self, x):
|
329 |
+
x = x + self.pe[:, :x.size(1)]
|
330 |
+
return self.dropout(x)
|
331 |
+
|
332 |
+
|
333 |
+
class TemporalSelfAttention(Attention):
|
334 |
+
def __init__(
|
335 |
+
self,
|
336 |
+
attention_mode=None,
|
337 |
+
temporal_position_encoding=False,
|
338 |
+
temporal_position_encoding_max_len=32,
|
339 |
+
rescale_output_factor=1.0,
|
340 |
+
*args, **kwargs
|
341 |
+
):
|
342 |
+
super().__init__(*args, **kwargs)
|
343 |
+
assert attention_mode == "Temporal_Self"
|
344 |
+
|
345 |
+
self.pos_encoder = PositionalEncoding(
|
346 |
+
kwargs["query_dim"],
|
347 |
+
max_len=temporal_position_encoding_max_len
|
348 |
+
) if temporal_position_encoding else None
|
349 |
+
self.rescale_output_factor = rescale_output_factor
|
350 |
+
|
351 |
+
def set_use_memory_efficient_attention_xformers(
|
352 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
353 |
+
):
|
354 |
+
# disable motion module efficient xformers to avoid bad results, don't know why
|
355 |
+
# TODO: fix this bug
|
356 |
+
pass
|
357 |
+
|
358 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
359 |
+
# The `Attention` class can call different attention processors / attention functions
|
360 |
+
# here we simply pass along all tensors to the selected processor class
|
361 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
362 |
+
|
363 |
+
# add position encoding
|
364 |
+
if self.pos_encoder is not None:
|
365 |
+
hidden_states = self.pos_encoder(hidden_states)
|
366 |
+
if "pose_feature" in cross_attention_kwargs:
|
367 |
+
pose_feature = cross_attention_kwargs["pose_feature"]
|
368 |
+
if pose_feature.ndim == 5:
|
369 |
+
pose_feature = rearrange(pose_feature, "b c f h w -> (b h w) f c")
|
370 |
+
else:
|
371 |
+
assert pose_feature.ndim == 3
|
372 |
+
cross_attention_kwargs["pose_feature"] = pose_feature
|
373 |
+
|
374 |
+
if isinstance(self.processor, PoseAdaptorAttnProcessor):
|
375 |
+
return self.processor(
|
376 |
+
self,
|
377 |
+
hidden_states,
|
378 |
+
cross_attention_kwargs.pop('pose_feature'),
|
379 |
+
encoder_hidden_states=None,
|
380 |
+
attention_mask=attention_mask,
|
381 |
+
**cross_attention_kwargs,
|
382 |
+
)
|
383 |
+
elif hasattr(self.processor, "__call__"):
|
384 |
+
return self.processor.__call__(
|
385 |
+
self,
|
386 |
+
hidden_states,
|
387 |
+
encoder_hidden_states=None,
|
388 |
+
attention_mask=attention_mask,
|
389 |
+
**cross_attention_kwargs,
|
390 |
+
)
|
391 |
+
else:
|
392 |
+
return self.processor(
|
393 |
+
self,
|
394 |
+
hidden_states,
|
395 |
+
encoder_hidden_states=None,
|
396 |
+
attention_mask=attention_mask,
|
397 |
+
**cross_attention_kwargs,
|
398 |
+
)
|
399 |
+
|
cameractrl/models/pose_adaptor.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange
|
5 |
+
from typing import List, Tuple
|
6 |
+
from cameractrl.models.motion_module import TemporalTransformerBlock
|
7 |
+
|
8 |
+
|
9 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
10 |
+
try:
|
11 |
+
params = tuple(parameter.parameters())
|
12 |
+
if len(params) > 0:
|
13 |
+
return params[0].dtype
|
14 |
+
|
15 |
+
buffers = tuple(parameter.buffers())
|
16 |
+
if len(buffers) > 0:
|
17 |
+
return buffers[0].dtype
|
18 |
+
|
19 |
+
except StopIteration:
|
20 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
21 |
+
|
22 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, torch.Tensor]]:
|
23 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
24 |
+
return tuples
|
25 |
+
|
26 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
27 |
+
first_tuple = next(gen)
|
28 |
+
return first_tuple[1].dtype
|
29 |
+
|
30 |
+
|
31 |
+
def conv_nd(dims, *args, **kwargs):
|
32 |
+
"""
|
33 |
+
Create a 1D, 2D, or 3D convolution module.
|
34 |
+
"""
|
35 |
+
if dims == 1:
|
36 |
+
return nn.Conv1d(*args, **kwargs)
|
37 |
+
elif dims == 2:
|
38 |
+
return nn.Conv2d(*args, **kwargs)
|
39 |
+
elif dims == 3:
|
40 |
+
return nn.Conv3d(*args, **kwargs)
|
41 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
42 |
+
|
43 |
+
|
44 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
45 |
+
"""
|
46 |
+
Create a 1D, 2D, or 3D average pooling module.
|
47 |
+
"""
|
48 |
+
if dims == 1:
|
49 |
+
return nn.AvgPool1d(*args, **kwargs)
|
50 |
+
elif dims == 2:
|
51 |
+
return nn.AvgPool2d(*args, **kwargs)
|
52 |
+
elif dims == 3:
|
53 |
+
return nn.AvgPool3d(*args, **kwargs)
|
54 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
55 |
+
|
56 |
+
|
57 |
+
class PoseAdaptor(nn.Module):
|
58 |
+
def __init__(self, unet, pose_encoder):
|
59 |
+
super().__init__()
|
60 |
+
self.unet = unet
|
61 |
+
self.pose_encoder = pose_encoder
|
62 |
+
|
63 |
+
def forward(self, noisy_latents, c_noise, encoder_hidden_states, added_time_ids, pose_embedding):
|
64 |
+
assert pose_embedding.ndim == 5
|
65 |
+
pose_embedding_features = self.pose_encoder(pose_embedding) # b c f h w
|
66 |
+
noise_pred = self.unet(noisy_latents,
|
67 |
+
c_noise,
|
68 |
+
encoder_hidden_states,
|
69 |
+
added_time_ids=added_time_ids,
|
70 |
+
pose_features=pose_embedding_features).sample
|
71 |
+
return noise_pred
|
72 |
+
|
73 |
+
|
74 |
+
class Downsample(nn.Module):
|
75 |
+
"""
|
76 |
+
A downsampling layer with an optional convolution.
|
77 |
+
:param channels: channels in the inputs and outputs.
|
78 |
+
:param use_conv: a bool determining if a convolution is applied.
|
79 |
+
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
80 |
+
downsampling occurs in the inner-two dimensions.
|
81 |
+
"""
|
82 |
+
|
83 |
+
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
84 |
+
super().__init__()
|
85 |
+
self.channels = channels
|
86 |
+
self.out_channels = out_channels or channels
|
87 |
+
self.use_conv = use_conv
|
88 |
+
self.dims = dims
|
89 |
+
stride = 2 if dims != 3 else (1, 2, 2)
|
90 |
+
if use_conv:
|
91 |
+
self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
92 |
+
else:
|
93 |
+
assert self.channels == self.out_channels
|
94 |
+
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
assert x.shape[1] == self.channels
|
98 |
+
return self.op(x)
|
99 |
+
|
100 |
+
|
101 |
+
class ResnetBlock(nn.Module):
|
102 |
+
|
103 |
+
def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
|
104 |
+
super().__init__()
|
105 |
+
ps = ksize // 2
|
106 |
+
if in_c != out_c or sk == False:
|
107 |
+
self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
108 |
+
else:
|
109 |
+
self.in_conv = None
|
110 |
+
self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
|
111 |
+
self.act = nn.ReLU()
|
112 |
+
self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
|
113 |
+
if sk == False:
|
114 |
+
self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
|
115 |
+
else:
|
116 |
+
self.skep = None
|
117 |
+
|
118 |
+
self.down = down
|
119 |
+
if self.down == True:
|
120 |
+
self.down_opt = Downsample(in_c, use_conv=use_conv)
|
121 |
+
|
122 |
+
def forward(self, x):
|
123 |
+
if self.down == True:
|
124 |
+
x = self.down_opt(x)
|
125 |
+
if self.in_conv is not None: # edit
|
126 |
+
x = self.in_conv(x)
|
127 |
+
|
128 |
+
h = self.block1(x)
|
129 |
+
h = self.act(h)
|
130 |
+
h = self.block2(h)
|
131 |
+
if self.skep is not None:
|
132 |
+
return h + self.skep(x)
|
133 |
+
else:
|
134 |
+
return h + x
|
135 |
+
|
136 |
+
|
137 |
+
class PositionalEncoding(nn.Module):
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
d_model,
|
141 |
+
dropout=0.,
|
142 |
+
max_len=32,
|
143 |
+
):
|
144 |
+
super().__init__()
|
145 |
+
self.dropout = nn.Dropout(p=dropout)
|
146 |
+
position = torch.arange(max_len).unsqueeze(1)
|
147 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
148 |
+
pe = torch.zeros(1, max_len, d_model)
|
149 |
+
pe[0, :, 0::2, ...] = torch.sin(position * div_term)
|
150 |
+
pe[0, :, 1::2, ...] = torch.cos(position * div_term)
|
151 |
+
pe.unsqueeze_(-1).unsqueeze_(-1)
|
152 |
+
self.register_buffer('pe', pe)
|
153 |
+
|
154 |
+
def forward(self, x):
|
155 |
+
x = x + self.pe[:, :x.size(1), ...]
|
156 |
+
return self.dropout(x)
|
157 |
+
|
158 |
+
|
159 |
+
class CameraPoseEncoder(nn.Module):
|
160 |
+
|
161 |
+
def __init__(self,
|
162 |
+
downscale_factor,
|
163 |
+
channels=[320, 640, 1280, 1280],
|
164 |
+
nums_rb=3,
|
165 |
+
cin=64,
|
166 |
+
ksize=3,
|
167 |
+
sk=False,
|
168 |
+
use_conv=True,
|
169 |
+
compression_factor=1,
|
170 |
+
temporal_attention_nhead=8,
|
171 |
+
attention_block_types=("Temporal_Self", ),
|
172 |
+
temporal_position_encoding=False,
|
173 |
+
temporal_position_encoding_max_len=16,
|
174 |
+
rescale_output_factor=1.0):
|
175 |
+
super(CameraPoseEncoder, self).__init__()
|
176 |
+
self.unshuffle = nn.PixelUnshuffle(downscale_factor)
|
177 |
+
self.channels = channels
|
178 |
+
self.nums_rb = nums_rb
|
179 |
+
self.encoder_down_conv_blocks = nn.ModuleList()
|
180 |
+
self.encoder_down_attention_blocks = nn.ModuleList()
|
181 |
+
for i in range(len(channels)):
|
182 |
+
conv_layers = nn.ModuleList()
|
183 |
+
temporal_attention_layers = nn.ModuleList()
|
184 |
+
for j in range(nums_rb):
|
185 |
+
if j == 0 and i != 0:
|
186 |
+
in_dim = channels[i - 1]
|
187 |
+
out_dim = int(channels[i] / compression_factor)
|
188 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=True, ksize=ksize, sk=sk, use_conv=use_conv)
|
189 |
+
elif j == 0:
|
190 |
+
in_dim = channels[0]
|
191 |
+
out_dim = int(channels[i] / compression_factor)
|
192 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
|
193 |
+
elif j == nums_rb - 1:
|
194 |
+
in_dim = channels[i] / compression_factor
|
195 |
+
out_dim = channels[i]
|
196 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
|
197 |
+
else:
|
198 |
+
in_dim = int(channels[i] / compression_factor)
|
199 |
+
out_dim = int(channels[i] / compression_factor)
|
200 |
+
conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
|
201 |
+
temporal_attention_layer = TemporalTransformerBlock(dim=out_dim,
|
202 |
+
num_attention_heads=temporal_attention_nhead,
|
203 |
+
attention_head_dim=int(out_dim / temporal_attention_nhead),
|
204 |
+
attention_block_types=attention_block_types,
|
205 |
+
dropout=0.0,
|
206 |
+
cross_attention_dim=None,
|
207 |
+
temporal_position_encoding=temporal_position_encoding,
|
208 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
209 |
+
rescale_output_factor=rescale_output_factor)
|
210 |
+
conv_layers.append(conv_layer)
|
211 |
+
temporal_attention_layers.append(temporal_attention_layer)
|
212 |
+
self.encoder_down_conv_blocks.append(conv_layers)
|
213 |
+
self.encoder_down_attention_blocks.append(temporal_attention_layers)
|
214 |
+
|
215 |
+
self.encoder_conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
|
216 |
+
|
217 |
+
@property
|
218 |
+
def dtype(self) -> torch.dtype:
|
219 |
+
"""
|
220 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
221 |
+
"""
|
222 |
+
return get_parameter_dtype(self)
|
223 |
+
|
224 |
+
def forward(self, x):
|
225 |
+
# unshuffle
|
226 |
+
bs = x.shape[0]
|
227 |
+
x = rearrange(x, "b f c h w -> (b f) c h w")
|
228 |
+
x = self.unshuffle(x)
|
229 |
+
# extract features
|
230 |
+
features = []
|
231 |
+
x = self.encoder_conv_in(x)
|
232 |
+
for res_block, attention_block in zip(self.encoder_down_conv_blocks, self.encoder_down_attention_blocks):
|
233 |
+
for res_layer, attention_layer in zip(res_block, attention_block):
|
234 |
+
x = res_layer(x)
|
235 |
+
h, w = x.shape[-2:]
|
236 |
+
x = rearrange(x, '(b f) c h w -> (b h w) f c', b=bs)
|
237 |
+
x = attention_layer(x)
|
238 |
+
x = rearrange(x, '(b h w) f c -> (b f) c h w', h=h, w=w)
|
239 |
+
features.append(rearrange(x, '(b f) c h w -> b c f h w', b=bs))
|
240 |
+
return features
|
cameractrl/models/transformer_temporal.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from typing import Optional
|
5 |
+
from diffusers.models.transformer_temporal import TransformerTemporalModelOutput
|
6 |
+
from diffusers.models.attention import BasicTransformerBlock
|
7 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
8 |
+
from diffusers.models.resnet import AlphaBlender
|
9 |
+
from cameractrl.models.attention import TemporalPoseCondTransformerBlock
|
10 |
+
|
11 |
+
|
12 |
+
class TransformerSpatioTemporalModelPoseCond(nn.Module):
|
13 |
+
"""
|
14 |
+
A Transformer model for video-like data.
|
15 |
+
|
16 |
+
Parameters:
|
17 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
18 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
19 |
+
in_channels (`int`, *optional*):
|
20 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
21 |
+
out_channels (`int`, *optional*):
|
22 |
+
The number of channels in the output (specify if the input is **continuous**).
|
23 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
24 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
num_attention_heads: int = 16,
|
30 |
+
attention_head_dim: int = 88,
|
31 |
+
in_channels: int = 320,
|
32 |
+
out_channels: Optional[int] = None,
|
33 |
+
num_layers: int = 1,
|
34 |
+
cross_attention_dim: Optional[int] = None,
|
35 |
+
):
|
36 |
+
super().__init__()
|
37 |
+
self.num_attention_heads = num_attention_heads
|
38 |
+
self.attention_head_dim = attention_head_dim
|
39 |
+
|
40 |
+
inner_dim = num_attention_heads * attention_head_dim
|
41 |
+
self.inner_dim = inner_dim
|
42 |
+
|
43 |
+
# 2. Define input layers
|
44 |
+
self.in_channels = in_channels
|
45 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
46 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
47 |
+
|
48 |
+
# 3. Define transformers blocks
|
49 |
+
self.transformer_blocks = nn.ModuleList(
|
50 |
+
[
|
51 |
+
BasicTransformerBlock(
|
52 |
+
inner_dim,
|
53 |
+
num_attention_heads,
|
54 |
+
attention_head_dim,
|
55 |
+
cross_attention_dim=cross_attention_dim,
|
56 |
+
)
|
57 |
+
for d in range(num_layers)
|
58 |
+
]
|
59 |
+
)
|
60 |
+
|
61 |
+
time_mix_inner_dim = inner_dim
|
62 |
+
self.temporal_transformer_blocks = nn.ModuleList(
|
63 |
+
[
|
64 |
+
TemporalPoseCondTransformerBlock(
|
65 |
+
inner_dim,
|
66 |
+
time_mix_inner_dim,
|
67 |
+
num_attention_heads,
|
68 |
+
attention_head_dim,
|
69 |
+
cross_attention_dim=cross_attention_dim,
|
70 |
+
)
|
71 |
+
for _ in range(num_layers)
|
72 |
+
]
|
73 |
+
)
|
74 |
+
|
75 |
+
time_embed_dim = in_channels * 4
|
76 |
+
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
77 |
+
self.time_proj = Timesteps(in_channels, True, 0)
|
78 |
+
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
79 |
+
|
80 |
+
# 4. Define output layers
|
81 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
82 |
+
# TODO: should use out_channels for continuous projections
|
83 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
84 |
+
|
85 |
+
self.gradient_checkpointing = False
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
hidden_states: torch.Tensor, # [bs * frame, c, h, w]
|
90 |
+
encoder_hidden_states: Optional[torch.Tensor] = None, # [bs * frame, 1, c]
|
91 |
+
image_only_indicator: Optional[torch.Tensor] = None, # [bs, frame]
|
92 |
+
pose_feature: Optional[torch.Tensor] = None, # [bs, c, frame, h, w]
|
93 |
+
return_dict: bool = True,
|
94 |
+
):
|
95 |
+
"""
|
96 |
+
Args:
|
97 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
98 |
+
Input hidden_states.
|
99 |
+
num_frames (`int`):
|
100 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
101 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
102 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
103 |
+
self-attention.
|
104 |
+
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
105 |
+
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
106 |
+
images, 0 indicates that the input contains video frames.
|
107 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
108 |
+
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
|
109 |
+
plain tuple.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
113 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
114 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
115 |
+
"""
|
116 |
+
# 1. Input
|
117 |
+
batch_frames, _, height, width = hidden_states.shape
|
118 |
+
num_frames = image_only_indicator.shape[-1]
|
119 |
+
batch_size = batch_frames // num_frames
|
120 |
+
|
121 |
+
time_context = encoder_hidden_states # [bs * frame, 1, c]
|
122 |
+
time_context_first_timestep = time_context[None, :].reshape(
|
123 |
+
batch_size, num_frames, -1, time_context.shape[-1]
|
124 |
+
)[:, 0] # [bs, frame, c]
|
125 |
+
time_context = time_context_first_timestep[:, None].broadcast_to(
|
126 |
+
batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
|
127 |
+
) # [bs, h*w, 1, c]
|
128 |
+
time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1]) # [bs * h * w, 1, c]
|
129 |
+
|
130 |
+
residual = hidden_states
|
131 |
+
|
132 |
+
hidden_states = self.norm(hidden_states) # [bs * frame, c, h, w]
|
133 |
+
inner_dim = hidden_states.shape[1]
|
134 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) # [bs * frame, h * w, c]
|
135 |
+
hidden_states = self.proj_in(hidden_states) # [bs * frame, h * w, c]
|
136 |
+
|
137 |
+
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
138 |
+
num_frames_emb = num_frames_emb.repeat(batch_size, 1) # [bs, frame]
|
139 |
+
num_frames_emb = num_frames_emb.reshape(-1) # [bs * frame]
|
140 |
+
t_emb = self.time_proj(num_frames_emb) # [bs * frame, c]
|
141 |
+
|
142 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
143 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
144 |
+
# there might be better ways to encapsulate this.
|
145 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
146 |
+
|
147 |
+
emb = self.time_pos_embed(t_emb)
|
148 |
+
emb = emb[:, None, :] # [bs * frame, 1, c]
|
149 |
+
|
150 |
+
# 2. Blocks
|
151 |
+
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
152 |
+
if self.training and self.gradient_checkpointing:
|
153 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
154 |
+
block,
|
155 |
+
hidden_states,
|
156 |
+
None,
|
157 |
+
encoder_hidden_states,
|
158 |
+
None,
|
159 |
+
use_reentrant=False,
|
160 |
+
)
|
161 |
+
else:
|
162 |
+
hidden_states = block(
|
163 |
+
hidden_states, # [bs * frame, h * w, c]
|
164 |
+
encoder_hidden_states=encoder_hidden_states, # [bs * frame, 1, c]
|
165 |
+
) # [bs * frame, h * w, c]
|
166 |
+
|
167 |
+
hidden_states_mix = hidden_states
|
168 |
+
hidden_states_mix = hidden_states_mix + emb
|
169 |
+
|
170 |
+
hidden_states_mix = temporal_block(
|
171 |
+
hidden_states_mix, # [bs * frame, h * w, c]
|
172 |
+
num_frames=num_frames,
|
173 |
+
encoder_hidden_states=time_context, # [bs * h * w, 1, c]
|
174 |
+
pose_feature=pose_feature
|
175 |
+
)
|
176 |
+
hidden_states = self.time_mixer(
|
177 |
+
x_spatial=hidden_states,
|
178 |
+
x_temporal=hidden_states_mix,
|
179 |
+
image_only_indicator=image_only_indicator,
|
180 |
+
)
|
181 |
+
|
182 |
+
# 3. Output
|
183 |
+
hidden_states = self.proj_out(hidden_states)
|
184 |
+
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
185 |
+
|
186 |
+
output = hidden_states + residual
|
187 |
+
|
188 |
+
if not return_dict:
|
189 |
+
return (output,)
|
190 |
+
|
191 |
+
return TransformerTemporalModelOutput(sample=output)
|
cameractrl/models/unet.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
|
7 |
+
from typing import List, Optional, Tuple, Union, Dict
|
8 |
+
|
9 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
10 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, CROSS_ATTENTION_PROCESSORS
|
11 |
+
from diffusers.models.modeling_utils import ModelMixin
|
12 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
13 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
14 |
+
from diffusers.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
|
15 |
+
|
16 |
+
from cameractrl.models.unet_3d_blocks import (
|
17 |
+
get_down_block,
|
18 |
+
get_up_block,
|
19 |
+
UNetMidBlockSpatioTemporalPoseCond
|
20 |
+
)
|
21 |
+
from cameractrl.models.attention_processor import XFormersAttnProcessor as CustomizedXFormerAttnProcessor
|
22 |
+
from cameractrl.models.attention_processor import PoseAdaptorXFormersAttnProcessor
|
23 |
+
|
24 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
25 |
+
from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor2_0 as PoseAdaptorAttnProcessor
|
26 |
+
from cameractrl.models.attention_processor import AttnProcessor2_0 as CustomizedAttnProcessor
|
27 |
+
else:
|
28 |
+
from cameractrl.models.attention_processor import PoseAdaptorAttnProcessor
|
29 |
+
from cameractrl.models.attention_processor import AttnProcessor as CustomizedAttnProcessor
|
30 |
+
|
31 |
+
class UNetSpatioTemporalConditionModelPoseCond(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
32 |
+
r"""
|
33 |
+
A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
|
34 |
+
shaped output.
|
35 |
+
|
36 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
37 |
+
for all models (such as downloading or saving).
|
38 |
+
|
39 |
+
Parameters:
|
40 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
41 |
+
Height and width of input/output sample.
|
42 |
+
in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
|
43 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
44 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
|
45 |
+
The tuple of downsample blocks to use.
|
46 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
|
47 |
+
The tuple of upsample blocks to use.
|
48 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
49 |
+
The tuple of output channels for each block.
|
50 |
+
addition_time_embed_dim: (`int`, defaults to 256):
|
51 |
+
Dimension to to encode the additional time ids.
|
52 |
+
projection_class_embeddings_input_dim (`int`, defaults to 768):
|
53 |
+
The dimension of the projection of encoded `added_time_ids`.
|
54 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
55 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
56 |
+
The dimension of the cross attention features.
|
57 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
58 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
59 |
+
[`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
|
60 |
+
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
|
61 |
+
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
|
62 |
+
The number of attention heads.
|
63 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
64 |
+
"""
|
65 |
+
|
66 |
+
_supports_gradient_checkpointing = True
|
67 |
+
|
68 |
+
@register_to_config
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
sample_size: Optional[int] = None,
|
72 |
+
in_channels: int = 8,
|
73 |
+
out_channels: int = 4,
|
74 |
+
down_block_types: Tuple[str] = (
|
75 |
+
"CrossAttnDownBlockSpatioTemporalPoseCond",
|
76 |
+
"CrossAttnDownBlockSpatioTemporalPoseCond",
|
77 |
+
"CrossAttnDownBlockSpatioTemporalPoseCond",
|
78 |
+
"DownBlockSpatioTemporal",
|
79 |
+
),
|
80 |
+
up_block_types: Tuple[str] = (
|
81 |
+
"UpBlockSpatioTemporal",
|
82 |
+
"CrossAttnUpBlockSpatioTemporalPoseCond",
|
83 |
+
"CrossAttnUpBlockSpatioTemporalPoseCond",
|
84 |
+
"CrossAttnUpBlockSpatioTemporalPoseCond",
|
85 |
+
),
|
86 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
87 |
+
addition_time_embed_dim: int = 256,
|
88 |
+
projection_class_embeddings_input_dim: int = 768,
|
89 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
90 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
91 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
92 |
+
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
|
93 |
+
num_frames: int = 25,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
self.sample_size = sample_size
|
98 |
+
|
99 |
+
# Check inputs
|
100 |
+
if len(down_block_types) != len(up_block_types):
|
101 |
+
raise ValueError(
|
102 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
103 |
+
)
|
104 |
+
|
105 |
+
if len(block_out_channels) != len(down_block_types):
|
106 |
+
raise ValueError(
|
107 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
108 |
+
)
|
109 |
+
|
110 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
111 |
+
raise ValueError(
|
112 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
113 |
+
)
|
114 |
+
|
115 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
116 |
+
raise ValueError(
|
117 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
118 |
+
)
|
119 |
+
|
120 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
121 |
+
raise ValueError(
|
122 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
123 |
+
)
|
124 |
+
|
125 |
+
# input
|
126 |
+
self.conv_in = nn.Conv2d(
|
127 |
+
in_channels,
|
128 |
+
block_out_channels[0],
|
129 |
+
kernel_size=3,
|
130 |
+
padding=1,
|
131 |
+
)
|
132 |
+
|
133 |
+
# time
|
134 |
+
time_embed_dim = block_out_channels[0] * 4
|
135 |
+
|
136 |
+
self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
|
137 |
+
timestep_input_dim = block_out_channels[0]
|
138 |
+
|
139 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
140 |
+
|
141 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
|
142 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
143 |
+
|
144 |
+
self.down_blocks = nn.ModuleList([])
|
145 |
+
self.up_blocks = nn.ModuleList([])
|
146 |
+
|
147 |
+
if isinstance(num_attention_heads, int):
|
148 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
149 |
+
|
150 |
+
if isinstance(cross_attention_dim, int):
|
151 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
152 |
+
|
153 |
+
if isinstance(layers_per_block, int):
|
154 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
155 |
+
|
156 |
+
if isinstance(transformer_layers_per_block, int):
|
157 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
158 |
+
|
159 |
+
blocks_time_embed_dim = time_embed_dim
|
160 |
+
|
161 |
+
# down
|
162 |
+
output_channel = block_out_channels[0]
|
163 |
+
for i, down_block_type in enumerate(down_block_types):
|
164 |
+
input_channel = output_channel
|
165 |
+
output_channel = block_out_channels[i]
|
166 |
+
is_final_block = i == len(block_out_channels) - 1
|
167 |
+
|
168 |
+
down_block = get_down_block(
|
169 |
+
down_block_type,
|
170 |
+
num_layers=layers_per_block[i],
|
171 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
172 |
+
in_channels=input_channel,
|
173 |
+
out_channels=output_channel,
|
174 |
+
temb_channels=blocks_time_embed_dim,
|
175 |
+
add_downsample=not is_final_block,
|
176 |
+
resnet_eps=1e-5,
|
177 |
+
cross_attention_dim=cross_attention_dim[i],
|
178 |
+
num_attention_heads=num_attention_heads[i],
|
179 |
+
resnet_act_fn="silu",
|
180 |
+
)
|
181 |
+
self.down_blocks.append(down_block)
|
182 |
+
|
183 |
+
# mid
|
184 |
+
self.mid_block = UNetMidBlockSpatioTemporalPoseCond(
|
185 |
+
block_out_channels[-1],
|
186 |
+
temb_channels=blocks_time_embed_dim,
|
187 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
188 |
+
cross_attention_dim=cross_attention_dim[-1],
|
189 |
+
num_attention_heads=num_attention_heads[-1],
|
190 |
+
)
|
191 |
+
|
192 |
+
# count how many layers upsample the images
|
193 |
+
self.num_upsamplers = 0
|
194 |
+
|
195 |
+
# up
|
196 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
197 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
198 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
199 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
200 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
201 |
+
|
202 |
+
output_channel = reversed_block_out_channels[0]
|
203 |
+
for i, up_block_type in enumerate(up_block_types):
|
204 |
+
is_final_block = i == len(block_out_channels) - 1
|
205 |
+
|
206 |
+
prev_output_channel = output_channel
|
207 |
+
output_channel = reversed_block_out_channels[i]
|
208 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
209 |
+
|
210 |
+
# add upsample block for all BUT final layer
|
211 |
+
if not is_final_block:
|
212 |
+
add_upsample = True
|
213 |
+
self.num_upsamplers += 1
|
214 |
+
else:
|
215 |
+
add_upsample = False
|
216 |
+
|
217 |
+
up_block = get_up_block(
|
218 |
+
up_block_type,
|
219 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
220 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
221 |
+
in_channels=input_channel,
|
222 |
+
out_channels=output_channel,
|
223 |
+
prev_output_channel=prev_output_channel,
|
224 |
+
temb_channels=blocks_time_embed_dim,
|
225 |
+
add_upsample=add_upsample,
|
226 |
+
resnet_eps=1e-5,
|
227 |
+
resolution_idx=i,
|
228 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
229 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
230 |
+
resnet_act_fn="silu",
|
231 |
+
)
|
232 |
+
self.up_blocks.append(up_block)
|
233 |
+
prev_output_channel = output_channel
|
234 |
+
|
235 |
+
# out
|
236 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
|
237 |
+
self.conv_act = nn.SiLU()
|
238 |
+
|
239 |
+
self.conv_out = nn.Conv2d(
|
240 |
+
block_out_channels[0],
|
241 |
+
out_channels,
|
242 |
+
kernel_size=3,
|
243 |
+
padding=1,
|
244 |
+
)
|
245 |
+
|
246 |
+
@property
|
247 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
248 |
+
r"""
|
249 |
+
Returns:
|
250 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
251 |
+
indexed by its weight name.
|
252 |
+
"""
|
253 |
+
# set recursively
|
254 |
+
processors = {}
|
255 |
+
|
256 |
+
def fn_recursive_add_processors(
|
257 |
+
name: str,
|
258 |
+
module: torch.nn.Module,
|
259 |
+
processors: Dict[str, AttentionProcessor],
|
260 |
+
):
|
261 |
+
if hasattr(module, "get_processor"):
|
262 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
263 |
+
|
264 |
+
for sub_name, child in module.named_children():
|
265 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
266 |
+
|
267 |
+
return processors
|
268 |
+
|
269 |
+
for name, module in self.named_children():
|
270 |
+
fn_recursive_add_processors(name, module, processors)
|
271 |
+
|
272 |
+
return processors
|
273 |
+
|
274 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
275 |
+
r"""
|
276 |
+
Sets the attention processor to use to compute attention.
|
277 |
+
|
278 |
+
Parameters:
|
279 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
280 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
281 |
+
for **all** `Attention` layers.
|
282 |
+
|
283 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
284 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
285 |
+
|
286 |
+
"""
|
287 |
+
count = len(self.attn_processors.keys())
|
288 |
+
|
289 |
+
if isinstance(processor, dict) and len(processor) != count:
|
290 |
+
raise ValueError(
|
291 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
292 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
293 |
+
)
|
294 |
+
|
295 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
296 |
+
if hasattr(module, "set_processor"):
|
297 |
+
if not isinstance(processor, dict):
|
298 |
+
module.set_processor(processor)
|
299 |
+
else:
|
300 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
301 |
+
|
302 |
+
for sub_name, child in module.named_children():
|
303 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
304 |
+
|
305 |
+
for name, module in self.named_children():
|
306 |
+
fn_recursive_attn_processor(name, module, processor)
|
307 |
+
|
308 |
+
def set_default_attn_processor(self):
|
309 |
+
"""
|
310 |
+
Disables custom attention processors and sets the default attention implementation.
|
311 |
+
"""
|
312 |
+
if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
313 |
+
processor = AttnProcessor()
|
314 |
+
else:
|
315 |
+
raise ValueError(
|
316 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
317 |
+
)
|
318 |
+
|
319 |
+
self.set_attn_processor(processor)
|
320 |
+
|
321 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
322 |
+
if hasattr(module, "gradient_checkpointing"):
|
323 |
+
module.gradient_checkpointing = value
|
324 |
+
|
325 |
+
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
326 |
+
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
327 |
+
"""
|
328 |
+
Sets the attention processor to use [feed forward
|
329 |
+
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
330 |
+
|
331 |
+
Parameters:
|
332 |
+
chunk_size (`int`, *optional*):
|
333 |
+
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
334 |
+
over each tensor of dim=`dim`.
|
335 |
+
dim (`int`, *optional*, defaults to `0`):
|
336 |
+
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
337 |
+
or dim=1 (sequence length).
|
338 |
+
"""
|
339 |
+
if dim not in [0, 1]:
|
340 |
+
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
341 |
+
|
342 |
+
# By default chunk size is 1
|
343 |
+
chunk_size = chunk_size or 1
|
344 |
+
|
345 |
+
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
346 |
+
if hasattr(module, "set_chunk_feed_forward"):
|
347 |
+
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
348 |
+
|
349 |
+
for child in module.children():
|
350 |
+
fn_recursive_feed_forward(child, chunk_size, dim)
|
351 |
+
|
352 |
+
for module in self.children():
|
353 |
+
fn_recursive_feed_forward(module, chunk_size, dim)
|
354 |
+
|
355 |
+
def set_pose_cond_attn_processor(self,
|
356 |
+
add_spatial=False,
|
357 |
+
add_temporal=False,
|
358 |
+
enable_xformers=False,
|
359 |
+
attn_processor_name='attn1',
|
360 |
+
pose_feature_dimensions=[320, 640, 1280, 1280],
|
361 |
+
**attention_processor_kwargs):
|
362 |
+
all_attn_processors = {}
|
363 |
+
set_processor_names = attn_processor_name.split(',')
|
364 |
+
if add_spatial:
|
365 |
+
for processor_key in self.attn_processors.keys():
|
366 |
+
if 'temporal' in processor_key:
|
367 |
+
continue
|
368 |
+
processor_name = processor_key.split('.')[-2]
|
369 |
+
cross_attention_dim = None if processor_name == 'attn1' else self.config.cross_attention_dim
|
370 |
+
if processor_key.startswith("mid_block"):
|
371 |
+
hidden_size = self.config.block_out_channels[-1]
|
372 |
+
block_id = -1
|
373 |
+
add_pose_adaptor = processor_name in set_processor_names
|
374 |
+
pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
|
375 |
+
elif processor_key.startswith("up_blocks"):
|
376 |
+
block_id = int(processor_key[len("up_blocks.")])
|
377 |
+
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
|
378 |
+
add_pose_adaptor = processor_name in set_processor_names
|
379 |
+
pose_feature_dim = list(reversed(pose_feature_dimensions))[block_id] if add_pose_adaptor else None
|
380 |
+
else:
|
381 |
+
block_id = int(processor_key[len("down_blocks.")])
|
382 |
+
hidden_size = self.config.block_out_channels[block_id]
|
383 |
+
add_pose_adaptor = processor_name in set_processor_names
|
384 |
+
pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
|
385 |
+
if add_pose_adaptor and enable_xformers:
|
386 |
+
all_attn_processors[processor_key] = PoseAdaptorXFormersAttnProcessor(hidden_size=hidden_size,
|
387 |
+
pose_feature_dim=pose_feature_dim,
|
388 |
+
cross_attention_dim=cross_attention_dim,
|
389 |
+
**attention_processor_kwargs)
|
390 |
+
elif add_pose_adaptor:
|
391 |
+
all_attn_processors[processor_key] = PoseAdaptorAttnProcessor(hidden_size=hidden_size,
|
392 |
+
pose_feature_dim=pose_feature_dim,
|
393 |
+
cross_attention_dim=cross_attention_dim,
|
394 |
+
**attention_processor_kwargs)
|
395 |
+
elif enable_xformers:
|
396 |
+
all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
|
397 |
+
else:
|
398 |
+
all_attn_processors[processor_key] = CustomizedAttnProcessor()
|
399 |
+
else:
|
400 |
+
for processor_key in self.attn_processors.keys():
|
401 |
+
if 'temporal' not in processor_key and enable_xformers:
|
402 |
+
all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
|
403 |
+
elif 'temporal' not in processor_key:
|
404 |
+
all_attn_processors[processor_key] = CustomizedAttnProcessor()
|
405 |
+
|
406 |
+
if add_temporal:
|
407 |
+
for processor_key in self.attn_processors.keys():
|
408 |
+
if 'temporal' not in processor_key:
|
409 |
+
continue
|
410 |
+
processor_name = processor_key.split('.')[-2]
|
411 |
+
cross_attention_dim = None if processor_name == 'attn1' else self.config.cross_attention_dim
|
412 |
+
if processor_key.startswith("mid_block"):
|
413 |
+
hidden_size = self.config.block_out_channels[-1]
|
414 |
+
block_id = -1
|
415 |
+
add_pose_adaptor = processor_name in set_processor_names
|
416 |
+
pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
|
417 |
+
elif processor_key.startswith("up_blocks"):
|
418 |
+
block_id = int(processor_key[len("up_blocks.")])
|
419 |
+
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
|
420 |
+
add_pose_adaptor = (processor_name in set_processor_names)
|
421 |
+
pose_feature_dim = list(reversed(pose_feature_dimensions))[block_id] if add_pose_adaptor else None
|
422 |
+
else:
|
423 |
+
block_id = int(processor_key[len("down_blocks.")])
|
424 |
+
hidden_size = self.config.block_out_channels[block_id]
|
425 |
+
add_pose_adaptor = processor_name in set_processor_names
|
426 |
+
pose_feature_dim = pose_feature_dimensions[block_id] if add_pose_adaptor else None
|
427 |
+
if add_pose_adaptor and enable_xformers:
|
428 |
+
all_attn_processors[processor_key] = PoseAdaptorAttnProcessor(hidden_size=hidden_size,
|
429 |
+
pose_feature_dim=pose_feature_dim,
|
430 |
+
cross_attention_dim=cross_attention_dim,
|
431 |
+
**attention_processor_kwargs)
|
432 |
+
elif add_pose_adaptor:
|
433 |
+
all_attn_processors[processor_key] = PoseAdaptorAttnProcessor(hidden_size=hidden_size,
|
434 |
+
pose_feature_dim=pose_feature_dim,
|
435 |
+
cross_attention_dim=cross_attention_dim,
|
436 |
+
**attention_processor_kwargs)
|
437 |
+
elif enable_xformers:
|
438 |
+
all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
|
439 |
+
else:
|
440 |
+
all_attn_processors[processor_key] = CustomizedAttnProcessor()
|
441 |
+
else:
|
442 |
+
for processor_key in self.attn_processors.keys():
|
443 |
+
if 'temporal' in processor_key and enable_xformers:
|
444 |
+
all_attn_processors[processor_key] = CustomizedXFormerAttnProcessor()
|
445 |
+
elif 'temporal' in processor_key:
|
446 |
+
all_attn_processors[processor_key] = CustomizedAttnProcessor()
|
447 |
+
|
448 |
+
self.set_attn_processor(all_attn_processors)
|
449 |
+
|
450 |
+
def forward(
|
451 |
+
self,
|
452 |
+
sample: torch.FloatTensor,
|
453 |
+
timestep: Union[torch.Tensor, float, int],
|
454 |
+
encoder_hidden_states: torch.Tensor,
|
455 |
+
added_time_ids: torch.Tensor,
|
456 |
+
pose_features: List[torch.Tensor] = None,
|
457 |
+
return_dict: bool = True,
|
458 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
459 |
+
r"""
|
460 |
+
The [`UNetSpatioTemporalConditionModel`] forward method.
|
461 |
+
|
462 |
+
Args:
|
463 |
+
sample (`torch.FloatTensor`):
|
464 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
|
465 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
466 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
467 |
+
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
|
468 |
+
added_time_ids: (`torch.FloatTensor`):
|
469 |
+
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
|
470 |
+
embeddings and added to the time embeddings.
|
471 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
472 |
+
Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
|
473 |
+
tuple.
|
474 |
+
Returns:
|
475 |
+
[`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
|
476 |
+
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
|
477 |
+
a `tuple` is returned where the first element is the sample tensor.
|
478 |
+
"""
|
479 |
+
# 1. time
|
480 |
+
timesteps = timestep
|
481 |
+
if not torch.is_tensor(timesteps):
|
482 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
483 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
484 |
+
is_mps = sample.device.type == "mps"
|
485 |
+
if isinstance(timestep, float):
|
486 |
+
dtype = torch.float32 if is_mps else torch.float64
|
487 |
+
else:
|
488 |
+
dtype = torch.int32 if is_mps else torch.int64
|
489 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
490 |
+
elif len(timesteps.shape) == 0:
|
491 |
+
timesteps = timesteps[None].to(sample.device)
|
492 |
+
|
493 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
494 |
+
batch_size, num_frames = sample.shape[:2]
|
495 |
+
timesteps = timesteps.expand(batch_size)
|
496 |
+
|
497 |
+
t_emb = self.time_proj(timesteps)
|
498 |
+
|
499 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
500 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
501 |
+
# there might be better ways to encapsulate this.
|
502 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
503 |
+
|
504 |
+
emb = self.time_embedding(t_emb)
|
505 |
+
|
506 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
507 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
508 |
+
time_embeds = time_embeds.to(emb.dtype)
|
509 |
+
aug_emb = self.add_embedding(time_embeds)
|
510 |
+
emb = emb + aug_emb
|
511 |
+
|
512 |
+
# Flatten the batch and frames dimensions
|
513 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
514 |
+
sample = sample.flatten(0, 1)
|
515 |
+
# Repeat the embeddings num_video_frames times
|
516 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
517 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
518 |
+
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
519 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
|
520 |
+
|
521 |
+
# 2. pre-process
|
522 |
+
sample = self.conv_in(sample)
|
523 |
+
|
524 |
+
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
|
525 |
+
|
526 |
+
down_block_res_samples = (sample,)
|
527 |
+
for block_idx, downsample_block in enumerate(self.down_blocks):
|
528 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
529 |
+
sample, res_samples = downsample_block(
|
530 |
+
hidden_states=sample,
|
531 |
+
temb=emb,
|
532 |
+
encoder_hidden_states=encoder_hidden_states,
|
533 |
+
image_only_indicator=image_only_indicator,
|
534 |
+
pose_feature=pose_features[block_idx]
|
535 |
+
)
|
536 |
+
else:
|
537 |
+
sample, res_samples = downsample_block(
|
538 |
+
hidden_states=sample,
|
539 |
+
temb=emb,
|
540 |
+
image_only_indicator=image_only_indicator,
|
541 |
+
)
|
542 |
+
|
543 |
+
down_block_res_samples += res_samples
|
544 |
+
|
545 |
+
# 4. mid
|
546 |
+
sample = self.mid_block(
|
547 |
+
hidden_states=sample,
|
548 |
+
temb=emb,
|
549 |
+
encoder_hidden_states=encoder_hidden_states,
|
550 |
+
image_only_indicator=image_only_indicator,
|
551 |
+
pose_feature=pose_features[-1]
|
552 |
+
)
|
553 |
+
|
554 |
+
# 5. up
|
555 |
+
for block_idx, upsample_block in enumerate(self.up_blocks):
|
556 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
557 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
558 |
+
|
559 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
560 |
+
sample = upsample_block(
|
561 |
+
hidden_states=sample,
|
562 |
+
temb=emb,
|
563 |
+
res_hidden_states_tuple=res_samples,
|
564 |
+
encoder_hidden_states=encoder_hidden_states,
|
565 |
+
image_only_indicator=image_only_indicator,
|
566 |
+
pose_feature=pose_features[-(block_idx + 1)]
|
567 |
+
)
|
568 |
+
else:
|
569 |
+
sample = upsample_block(
|
570 |
+
hidden_states=sample,
|
571 |
+
temb=emb,
|
572 |
+
res_hidden_states_tuple=res_samples,
|
573 |
+
image_only_indicator=image_only_indicator,
|
574 |
+
)
|
575 |
+
|
576 |
+
# 6. post-process
|
577 |
+
sample = self.conv_norm_out(sample)
|
578 |
+
sample = self.conv_act(sample)
|
579 |
+
sample = self.conv_out(sample)
|
580 |
+
|
581 |
+
# 7. Reshape back to original shape
|
582 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
583 |
+
|
584 |
+
if not return_dict:
|
585 |
+
return (sample,)
|
586 |
+
|
587 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
cameractrl/models/unet_3d_blocks.py
ADDED
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import Union, Tuple, Optional, Dict, Any
|
4 |
+
from diffusers.utils import is_torch_version
|
5 |
+
from diffusers.models.resnet import (
|
6 |
+
Downsample2D,
|
7 |
+
SpatioTemporalResBlock,
|
8 |
+
Upsample2D
|
9 |
+
)
|
10 |
+
from diffusers.models.unet_3d_blocks import (
|
11 |
+
DownBlockSpatioTemporal,
|
12 |
+
UpBlockSpatioTemporal,
|
13 |
+
)
|
14 |
+
|
15 |
+
from cameractrl.models.transformer_temporal import TransformerSpatioTemporalModelPoseCond
|
16 |
+
|
17 |
+
|
18 |
+
def get_down_block(
|
19 |
+
down_block_type: str,
|
20 |
+
num_layers: int,
|
21 |
+
in_channels: int,
|
22 |
+
out_channels: int,
|
23 |
+
temb_channels: int,
|
24 |
+
add_downsample: bool,
|
25 |
+
num_attention_heads: int,
|
26 |
+
cross_attention_dim: Optional[int] = None,
|
27 |
+
transformer_layers_per_block: int = 1,
|
28 |
+
**kwargs,
|
29 |
+
) -> Union[
|
30 |
+
"DownBlockSpatioTemporal",
|
31 |
+
"CrossAttnDownBlockSpatioTemporalPoseCond",
|
32 |
+
]:
|
33 |
+
if down_block_type == "DownBlockSpatioTemporal":
|
34 |
+
# added for SDV
|
35 |
+
return DownBlockSpatioTemporal(
|
36 |
+
num_layers=num_layers,
|
37 |
+
in_channels=in_channels,
|
38 |
+
out_channels=out_channels,
|
39 |
+
temb_channels=temb_channels,
|
40 |
+
add_downsample=add_downsample,
|
41 |
+
)
|
42 |
+
elif down_block_type == "CrossAttnDownBlockSpatioTemporalPoseCond":
|
43 |
+
# added for SDV
|
44 |
+
if cross_attention_dim is None:
|
45 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
|
46 |
+
return CrossAttnDownBlockSpatioTemporalPoseCond(
|
47 |
+
in_channels=in_channels,
|
48 |
+
out_channels=out_channels,
|
49 |
+
temb_channels=temb_channels,
|
50 |
+
num_layers=num_layers,
|
51 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
52 |
+
add_downsample=add_downsample,
|
53 |
+
cross_attention_dim=cross_attention_dim,
|
54 |
+
num_attention_heads=num_attention_heads,
|
55 |
+
)
|
56 |
+
|
57 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
58 |
+
|
59 |
+
|
60 |
+
def get_up_block(
|
61 |
+
up_block_type: str,
|
62 |
+
num_layers: int,
|
63 |
+
in_channels: int,
|
64 |
+
out_channels: int,
|
65 |
+
prev_output_channel: int,
|
66 |
+
temb_channels: int,
|
67 |
+
add_upsample: bool,
|
68 |
+
num_attention_heads: int,
|
69 |
+
resolution_idx: Optional[int] = None,
|
70 |
+
cross_attention_dim: Optional[int] = None,
|
71 |
+
transformer_layers_per_block: int = 1,
|
72 |
+
**kwargs,
|
73 |
+
) -> Union[
|
74 |
+
"UpBlockSpatioTemporal",
|
75 |
+
"CrossAttnUpBlockSpatioTemporalPoseCond",
|
76 |
+
]:
|
77 |
+
if up_block_type == "UpBlockSpatioTemporal":
|
78 |
+
# added for SDV
|
79 |
+
return UpBlockSpatioTemporal(
|
80 |
+
num_layers=num_layers,
|
81 |
+
in_channels=in_channels,
|
82 |
+
out_channels=out_channels,
|
83 |
+
prev_output_channel=prev_output_channel,
|
84 |
+
temb_channels=temb_channels,
|
85 |
+
resolution_idx=resolution_idx,
|
86 |
+
add_upsample=add_upsample,
|
87 |
+
)
|
88 |
+
elif up_block_type == "CrossAttnUpBlockSpatioTemporalPoseCond":
|
89 |
+
# added for SDV
|
90 |
+
if cross_attention_dim is None:
|
91 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
|
92 |
+
return CrossAttnUpBlockSpatioTemporalPoseCond(
|
93 |
+
in_channels=in_channels,
|
94 |
+
out_channels=out_channels,
|
95 |
+
prev_output_channel=prev_output_channel,
|
96 |
+
temb_channels=temb_channels,
|
97 |
+
num_layers=num_layers,
|
98 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
99 |
+
add_upsample=add_upsample,
|
100 |
+
cross_attention_dim=cross_attention_dim,
|
101 |
+
num_attention_heads=num_attention_heads,
|
102 |
+
resolution_idx=resolution_idx,
|
103 |
+
)
|
104 |
+
|
105 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
106 |
+
|
107 |
+
|
108 |
+
class CrossAttnDownBlockSpatioTemporalPoseCond(nn.Module):
|
109 |
+
def __init__(
|
110 |
+
self,
|
111 |
+
in_channels: int,
|
112 |
+
out_channels: int,
|
113 |
+
temb_channels: int,
|
114 |
+
num_layers: int = 1,
|
115 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
116 |
+
num_attention_heads: int = 1,
|
117 |
+
cross_attention_dim: int = 1280,
|
118 |
+
add_downsample: bool = True,
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
resnets = []
|
122 |
+
attentions = []
|
123 |
+
|
124 |
+
self.has_cross_attention = True
|
125 |
+
self.num_attention_heads = num_attention_heads
|
126 |
+
if isinstance(transformer_layers_per_block, int):
|
127 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
128 |
+
|
129 |
+
for i in range(num_layers):
|
130 |
+
in_channels = in_channels if i == 0 else out_channels
|
131 |
+
resnets.append(
|
132 |
+
SpatioTemporalResBlock(
|
133 |
+
in_channels=in_channels,
|
134 |
+
out_channels=out_channels,
|
135 |
+
temb_channels=temb_channels,
|
136 |
+
eps=1e-6,
|
137 |
+
)
|
138 |
+
)
|
139 |
+
attentions.append(
|
140 |
+
TransformerSpatioTemporalModelPoseCond(
|
141 |
+
num_attention_heads,
|
142 |
+
out_channels // num_attention_heads,
|
143 |
+
in_channels=out_channels,
|
144 |
+
num_layers=transformer_layers_per_block[i],
|
145 |
+
cross_attention_dim=cross_attention_dim,
|
146 |
+
)
|
147 |
+
)
|
148 |
+
|
149 |
+
self.attentions = nn.ModuleList(attentions)
|
150 |
+
self.resnets = nn.ModuleList(resnets)
|
151 |
+
|
152 |
+
if add_downsample:
|
153 |
+
self.downsamplers = nn.ModuleList(
|
154 |
+
[
|
155 |
+
Downsample2D(
|
156 |
+
out_channels,
|
157 |
+
use_conv=True,
|
158 |
+
out_channels=out_channels,
|
159 |
+
padding=1,
|
160 |
+
name="op",
|
161 |
+
)
|
162 |
+
]
|
163 |
+
)
|
164 |
+
else:
|
165 |
+
self.downsamplers = None
|
166 |
+
|
167 |
+
self.gradient_checkpointing = False
|
168 |
+
|
169 |
+
def forward(
|
170 |
+
self,
|
171 |
+
hidden_states: torch.FloatTensor, # [bs * frame, c, h, w]
|
172 |
+
temb: Optional[torch.FloatTensor] = None, # [bs * frame, c]
|
173 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None, # [bs * frame, 1, c]
|
174 |
+
image_only_indicator: Optional[torch.Tensor] = None, # [bs, frame]
|
175 |
+
pose_feature: Optional[torch.Tensor] = None # [bs, c, frame, h, w]
|
176 |
+
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
177 |
+
output_states = ()
|
178 |
+
|
179 |
+
blocks = list(zip(self.resnets, self.attentions))
|
180 |
+
for resnet, attn in blocks:
|
181 |
+
if self.training and self.gradient_checkpointing: # TODO
|
182 |
+
|
183 |
+
def create_custom_forward(module, return_dict=None):
|
184 |
+
def custom_forward(*inputs):
|
185 |
+
if return_dict is not None:
|
186 |
+
return module(*inputs, return_dict=return_dict)
|
187 |
+
else:
|
188 |
+
return module(*inputs)
|
189 |
+
|
190 |
+
return custom_forward
|
191 |
+
|
192 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
193 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
194 |
+
create_custom_forward(resnet),
|
195 |
+
hidden_states,
|
196 |
+
temb,
|
197 |
+
image_only_indicator,
|
198 |
+
**ckpt_kwargs,
|
199 |
+
)
|
200 |
+
|
201 |
+
hidden_states = attn(
|
202 |
+
hidden_states,
|
203 |
+
encoder_hidden_states=encoder_hidden_states,
|
204 |
+
image_only_indicator=image_only_indicator,
|
205 |
+
return_dict=False,
|
206 |
+
)[0]
|
207 |
+
else:
|
208 |
+
hidden_states = resnet(
|
209 |
+
hidden_states,
|
210 |
+
temb,
|
211 |
+
image_only_indicator=image_only_indicator,
|
212 |
+
) # [bs * frame, c, h, w]
|
213 |
+
hidden_states = attn(
|
214 |
+
hidden_states,
|
215 |
+
encoder_hidden_states=encoder_hidden_states,
|
216 |
+
image_only_indicator=image_only_indicator,
|
217 |
+
pose_feature=pose_feature,
|
218 |
+
return_dict=False,
|
219 |
+
)[0]
|
220 |
+
|
221 |
+
output_states = output_states + (hidden_states,)
|
222 |
+
|
223 |
+
if self.downsamplers is not None:
|
224 |
+
for downsampler in self.downsamplers:
|
225 |
+
hidden_states = downsampler(hidden_states)
|
226 |
+
|
227 |
+
output_states = output_states + (hidden_states,)
|
228 |
+
|
229 |
+
return hidden_states, output_states
|
230 |
+
|
231 |
+
|
232 |
+
class UNetMidBlockSpatioTemporalPoseCond(nn.Module):
|
233 |
+
def __init__(
|
234 |
+
self,
|
235 |
+
in_channels: int,
|
236 |
+
temb_channels: int,
|
237 |
+
num_layers: int = 1,
|
238 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
239 |
+
num_attention_heads: int = 1,
|
240 |
+
cross_attention_dim: int = 1280,
|
241 |
+
):
|
242 |
+
super().__init__()
|
243 |
+
|
244 |
+
self.has_cross_attention = True
|
245 |
+
self.num_attention_heads = num_attention_heads
|
246 |
+
|
247 |
+
# support for variable transformer layers per block
|
248 |
+
if isinstance(transformer_layers_per_block, int):
|
249 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
250 |
+
|
251 |
+
# there is always at least one resnet
|
252 |
+
resnets = [
|
253 |
+
SpatioTemporalResBlock(
|
254 |
+
in_channels=in_channels,
|
255 |
+
out_channels=in_channels,
|
256 |
+
temb_channels=temb_channels,
|
257 |
+
eps=1e-5,
|
258 |
+
)
|
259 |
+
]
|
260 |
+
attentions = []
|
261 |
+
|
262 |
+
for i in range(num_layers):
|
263 |
+
attentions.append(
|
264 |
+
TransformerSpatioTemporalModelPoseCond(
|
265 |
+
num_attention_heads,
|
266 |
+
in_channels // num_attention_heads,
|
267 |
+
in_channels=in_channels,
|
268 |
+
num_layers=transformer_layers_per_block[i],
|
269 |
+
cross_attention_dim=cross_attention_dim,
|
270 |
+
)
|
271 |
+
)
|
272 |
+
|
273 |
+
resnets.append(
|
274 |
+
SpatioTemporalResBlock(
|
275 |
+
in_channels=in_channels,
|
276 |
+
out_channels=in_channels,
|
277 |
+
temb_channels=temb_channels,
|
278 |
+
eps=1e-5,
|
279 |
+
)
|
280 |
+
)
|
281 |
+
|
282 |
+
self.attentions = nn.ModuleList(attentions)
|
283 |
+
self.resnets = nn.ModuleList(resnets)
|
284 |
+
|
285 |
+
self.gradient_checkpointing = False
|
286 |
+
|
287 |
+
def forward(
|
288 |
+
self,
|
289 |
+
hidden_states: torch.FloatTensor,
|
290 |
+
temb: Optional[torch.FloatTensor] = None,
|
291 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
292 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
293 |
+
pose_feature: Optional[torch.Tensor] = None # [bs, c, frame, h, w]
|
294 |
+
) -> torch.FloatTensor:
|
295 |
+
hidden_states = self.resnets[0](
|
296 |
+
hidden_states,
|
297 |
+
temb,
|
298 |
+
image_only_indicator=image_only_indicator,
|
299 |
+
)
|
300 |
+
|
301 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
302 |
+
if self.training and self.gradient_checkpointing: # TODO
|
303 |
+
|
304 |
+
def create_custom_forward(module, return_dict=None):
|
305 |
+
def custom_forward(*inputs):
|
306 |
+
if return_dict is not None:
|
307 |
+
return module(*inputs, return_dict=return_dict)
|
308 |
+
else:
|
309 |
+
return module(*inputs)
|
310 |
+
|
311 |
+
return custom_forward
|
312 |
+
|
313 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
314 |
+
hidden_states = attn(
|
315 |
+
hidden_states,
|
316 |
+
encoder_hidden_states=encoder_hidden_states,
|
317 |
+
image_only_indicator=image_only_indicator,
|
318 |
+
return_dict=False,
|
319 |
+
)[0]
|
320 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
321 |
+
create_custom_forward(resnet),
|
322 |
+
hidden_states,
|
323 |
+
temb,
|
324 |
+
image_only_indicator,
|
325 |
+
**ckpt_kwargs,
|
326 |
+
)
|
327 |
+
else:
|
328 |
+
hidden_states = attn(
|
329 |
+
hidden_states,
|
330 |
+
encoder_hidden_states=encoder_hidden_states,
|
331 |
+
image_only_indicator=image_only_indicator,
|
332 |
+
pose_feature=pose_feature,
|
333 |
+
return_dict=False,
|
334 |
+
)[0]
|
335 |
+
hidden_states = resnet(
|
336 |
+
hidden_states,
|
337 |
+
temb,
|
338 |
+
image_only_indicator=image_only_indicator,
|
339 |
+
)
|
340 |
+
|
341 |
+
return hidden_states
|
342 |
+
|
343 |
+
|
344 |
+
class CrossAttnUpBlockSpatioTemporalPoseCond(nn.Module):
|
345 |
+
def __init__(
|
346 |
+
self,
|
347 |
+
in_channels: int,
|
348 |
+
out_channels: int,
|
349 |
+
prev_output_channel: int,
|
350 |
+
temb_channels: int,
|
351 |
+
resolution_idx: Optional[int] = None,
|
352 |
+
num_layers: int = 1,
|
353 |
+
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
354 |
+
resnet_eps: float = 1e-6,
|
355 |
+
num_attention_heads: int = 1,
|
356 |
+
cross_attention_dim: int = 1280,
|
357 |
+
add_upsample: bool = True,
|
358 |
+
):
|
359 |
+
super().__init__()
|
360 |
+
resnets = []
|
361 |
+
attentions = []
|
362 |
+
|
363 |
+
self.has_cross_attention = True
|
364 |
+
self.num_attention_heads = num_attention_heads
|
365 |
+
|
366 |
+
if isinstance(transformer_layers_per_block, int):
|
367 |
+
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
368 |
+
|
369 |
+
for i in range(num_layers):
|
370 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
371 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
372 |
+
|
373 |
+
resnets.append(
|
374 |
+
SpatioTemporalResBlock(
|
375 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
376 |
+
out_channels=out_channels,
|
377 |
+
temb_channels=temb_channels,
|
378 |
+
eps=resnet_eps,
|
379 |
+
)
|
380 |
+
)
|
381 |
+
attentions.append(
|
382 |
+
TransformerSpatioTemporalModelPoseCond(
|
383 |
+
num_attention_heads,
|
384 |
+
out_channels // num_attention_heads,
|
385 |
+
in_channels=out_channels,
|
386 |
+
num_layers=transformer_layers_per_block[i],
|
387 |
+
cross_attention_dim=cross_attention_dim,
|
388 |
+
)
|
389 |
+
)
|
390 |
+
|
391 |
+
self.attentions = nn.ModuleList(attentions)
|
392 |
+
self.resnets = nn.ModuleList(resnets)
|
393 |
+
|
394 |
+
if add_upsample:
|
395 |
+
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
|
396 |
+
else:
|
397 |
+
self.upsamplers = None
|
398 |
+
|
399 |
+
self.gradient_checkpointing = False
|
400 |
+
self.resolution_idx = resolution_idx
|
401 |
+
|
402 |
+
def forward(
|
403 |
+
self,
|
404 |
+
hidden_states: torch.FloatTensor,
|
405 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
406 |
+
temb: Optional[torch.FloatTensor] = None,
|
407 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
408 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
409 |
+
pose_feature: Optional[torch.Tensor] = None # [bs, c, frame, h, w]
|
410 |
+
) -> torch.FloatTensor:
|
411 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
412 |
+
# pop res hidden states
|
413 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
414 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
415 |
+
|
416 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
417 |
+
|
418 |
+
if self.training and self.gradient_checkpointing: # TODO
|
419 |
+
|
420 |
+
def create_custom_forward(module, return_dict=None):
|
421 |
+
def custom_forward(*inputs):
|
422 |
+
if return_dict is not None:
|
423 |
+
return module(*inputs, return_dict=return_dict)
|
424 |
+
else:
|
425 |
+
return module(*inputs)
|
426 |
+
|
427 |
+
return custom_forward
|
428 |
+
|
429 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
430 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
431 |
+
create_custom_forward(resnet),
|
432 |
+
hidden_states,
|
433 |
+
temb,
|
434 |
+
image_only_indicator,
|
435 |
+
**ckpt_kwargs,
|
436 |
+
)
|
437 |
+
hidden_states = attn(
|
438 |
+
hidden_states,
|
439 |
+
encoder_hidden_states=encoder_hidden_states,
|
440 |
+
image_only_indicator=image_only_indicator,
|
441 |
+
return_dict=False,
|
442 |
+
)[0]
|
443 |
+
else:
|
444 |
+
hidden_states = resnet(
|
445 |
+
hidden_states,
|
446 |
+
temb,
|
447 |
+
image_only_indicator=image_only_indicator,
|
448 |
+
)
|
449 |
+
hidden_states = attn(
|
450 |
+
hidden_states,
|
451 |
+
encoder_hidden_states=encoder_hidden_states,
|
452 |
+
image_only_indicator=image_only_indicator,
|
453 |
+
pose_feature=pose_feature,
|
454 |
+
return_dict=False,
|
455 |
+
)[0]
|
456 |
+
|
457 |
+
if self.upsamplers is not None:
|
458 |
+
for upsampler in self.upsamplers:
|
459 |
+
hidden_states = upsampler(hidden_states)
|
460 |
+
|
461 |
+
return hidden_states
|
cameractrl/pipelines/pipeline_animation.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
import PIL.Image
|
8 |
+
|
9 |
+
from typing import Callable, List, Optional, Union, Dict
|
10 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
11 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
12 |
+
from diffusers.image_processor import VaeImageProcessor
|
13 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
14 |
+
from diffusers.schedulers import EulerDiscreteScheduler
|
15 |
+
from diffusers.utils import logging
|
16 |
+
from diffusers.utils.torch_utils import randn_tensor
|
17 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
18 |
+
_resize_with_antialiasing,
|
19 |
+
_append_dims,
|
20 |
+
tensor2vid,
|
21 |
+
StableVideoDiffusionPipelineOutput
|
22 |
+
)
|
23 |
+
|
24 |
+
from cameractrl.models.pose_adaptor import CameraPoseEncoder
|
25 |
+
from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond
|
26 |
+
|
27 |
+
|
28 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
29 |
+
|
30 |
+
|
31 |
+
class StableVideoDiffusionPipelinePoseCond(DiffusionPipeline):
|
32 |
+
r"""
|
33 |
+
Pipeline to generate video from an input image using Stable Video Diffusion.
|
34 |
+
|
35 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
36 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
37 |
+
|
38 |
+
Args:
|
39 |
+
vae ([`AutoencoderKL`]):
|
40 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
41 |
+
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
|
42 |
+
Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
|
43 |
+
unet ([`UNetSpatioTemporalConditionModel`]):
|
44 |
+
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
|
45 |
+
scheduler ([`EulerDiscreteScheduler`]):
|
46 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
47 |
+
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
48 |
+
A `CLIPImageProcessor` to extract features from generated images.
|
49 |
+
"""
|
50 |
+
|
51 |
+
model_cpu_offload_seq = "image_encoder->unet->vae"
|
52 |
+
_callback_tensor_inputs = ["latents"]
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
vae: AutoencoderKLTemporalDecoder,
|
57 |
+
image_encoder: CLIPVisionModelWithProjection,
|
58 |
+
unet: UNetSpatioTemporalConditionModelPoseCond,
|
59 |
+
scheduler: EulerDiscreteScheduler,
|
60 |
+
feature_extractor: CLIPImageProcessor,
|
61 |
+
pose_encoder: CameraPoseEncoder
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.register_modules(
|
66 |
+
vae=vae,
|
67 |
+
image_encoder=image_encoder,
|
68 |
+
unet=unet,
|
69 |
+
scheduler=scheduler,
|
70 |
+
feature_extractor=feature_extractor,
|
71 |
+
pose_encoder=pose_encoder
|
72 |
+
)
|
73 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
74 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
75 |
+
|
76 |
+
def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, do_resize_normalize):
|
77 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
78 |
+
|
79 |
+
if not isinstance(image, torch.Tensor):
|
80 |
+
image = self.image_processor.pil_to_numpy(image)
|
81 |
+
image = self.image_processor.numpy_to_pt(image)
|
82 |
+
|
83 |
+
# We normalize the image before resizing to match with the original implementation.
|
84 |
+
# Then we unnormalize it after resizing.
|
85 |
+
image = image * 2.0 - 1.0
|
86 |
+
image = _resize_with_antialiasing(image, (224, 224))
|
87 |
+
image = (image + 1.0) / 2.0
|
88 |
+
|
89 |
+
# Normalize the image with for CLIP input
|
90 |
+
image = self.feature_extractor(
|
91 |
+
images=image,
|
92 |
+
do_normalize=True,
|
93 |
+
do_center_crop=False,
|
94 |
+
do_resize=False,
|
95 |
+
do_rescale=False,
|
96 |
+
return_tensors="pt",
|
97 |
+
).pixel_values
|
98 |
+
elif do_resize_normalize:
|
99 |
+
image = _resize_with_antialiasing(image, (224, 224))
|
100 |
+
image = (image + 1.0) / 2.0
|
101 |
+
# Normalize the image with for CLIP input
|
102 |
+
image = self.feature_extractor(
|
103 |
+
images=image,
|
104 |
+
do_normalize=True,
|
105 |
+
do_center_crop=False,
|
106 |
+
do_resize=False,
|
107 |
+
do_rescale=False,
|
108 |
+
return_tensors="pt",
|
109 |
+
).pixel_values
|
110 |
+
|
111 |
+
image = image.to(device=device, dtype=dtype)
|
112 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
113 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
114 |
+
|
115 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
116 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
117 |
+
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
118 |
+
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
119 |
+
|
120 |
+
if do_classifier_free_guidance:
|
121 |
+
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
122 |
+
|
123 |
+
# For classifier free guidance, we need to do two forward passes.
|
124 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
125 |
+
# to avoid doing two forward passes
|
126 |
+
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
127 |
+
|
128 |
+
return image_embeddings
|
129 |
+
|
130 |
+
def _encode_vae_image(
|
131 |
+
self,
|
132 |
+
image: torch.Tensor,
|
133 |
+
device,
|
134 |
+
num_videos_per_prompt,
|
135 |
+
do_classifier_free_guidance,
|
136 |
+
):
|
137 |
+
image = image.to(device=device)
|
138 |
+
image_latents = self.vae.encode(image).latent_dist.mode()
|
139 |
+
|
140 |
+
if do_classifier_free_guidance:
|
141 |
+
negative_image_latents = torch.zeros_like(image_latents)
|
142 |
+
|
143 |
+
# For classifier free guidance, we need to do two forward passes.
|
144 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
145 |
+
# to avoid doing two forward passes
|
146 |
+
image_latents = torch.cat([negative_image_latents, image_latents])
|
147 |
+
|
148 |
+
# duplicate image_latents for each generation per prompt, using mps friendly method
|
149 |
+
image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
|
150 |
+
|
151 |
+
return image_latents
|
152 |
+
|
153 |
+
def _get_add_time_ids(
|
154 |
+
self,
|
155 |
+
fps,
|
156 |
+
motion_bucket_id,
|
157 |
+
noise_aug_strength,
|
158 |
+
dtype,
|
159 |
+
batch_size,
|
160 |
+
num_videos_per_prompt,
|
161 |
+
do_classifier_free_guidance,
|
162 |
+
):
|
163 |
+
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
|
164 |
+
|
165 |
+
passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
|
166 |
+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
167 |
+
|
168 |
+
if expected_add_embed_dim != passed_add_embed_dim:
|
169 |
+
raise ValueError(
|
170 |
+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
171 |
+
)
|
172 |
+
|
173 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
174 |
+
add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
|
175 |
+
|
176 |
+
if do_classifier_free_guidance:
|
177 |
+
add_time_ids = torch.cat([add_time_ids, add_time_ids])
|
178 |
+
|
179 |
+
return add_time_ids
|
180 |
+
|
181 |
+
def decode_latents(self, latents, num_frames, decode_chunk_size=14):
|
182 |
+
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
183 |
+
latents = latents.flatten(0, 1)
|
184 |
+
|
185 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
186 |
+
|
187 |
+
accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
|
188 |
+
|
189 |
+
# decode decode_chunk_size frames at a time to avoid OOM
|
190 |
+
frames = []
|
191 |
+
for i in range(0, latents.shape[0], decode_chunk_size):
|
192 |
+
num_frames_in = latents[i : i + decode_chunk_size].shape[0]
|
193 |
+
decode_kwargs = {}
|
194 |
+
if accepts_num_frames:
|
195 |
+
# we only pass num_frames_in if it's expected
|
196 |
+
decode_kwargs["num_frames"] = num_frames_in
|
197 |
+
|
198 |
+
frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
|
199 |
+
frames.append(frame)
|
200 |
+
frames = torch.cat(frames, dim=0)
|
201 |
+
|
202 |
+
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
203 |
+
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
|
204 |
+
|
205 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
206 |
+
frames = frames.float()
|
207 |
+
return frames
|
208 |
+
|
209 |
+
def check_inputs(self, image, height, width):
|
210 |
+
if (
|
211 |
+
not isinstance(image, torch.Tensor)
|
212 |
+
and not isinstance(image, PIL.Image.Image)
|
213 |
+
and not isinstance(image, list)
|
214 |
+
):
|
215 |
+
raise ValueError(
|
216 |
+
"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
217 |
+
f" {type(image)}"
|
218 |
+
)
|
219 |
+
|
220 |
+
if height % 8 != 0 or width % 8 != 0:
|
221 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
222 |
+
|
223 |
+
def prepare_latents(
|
224 |
+
self,
|
225 |
+
batch_size,
|
226 |
+
num_frames,
|
227 |
+
num_channels_latents,
|
228 |
+
height,
|
229 |
+
width,
|
230 |
+
dtype,
|
231 |
+
device,
|
232 |
+
generator,
|
233 |
+
latents=None,
|
234 |
+
):
|
235 |
+
shape = (
|
236 |
+
batch_size,
|
237 |
+
num_frames,
|
238 |
+
num_channels_latents // 2,
|
239 |
+
height // self.vae_scale_factor,
|
240 |
+
width // self.vae_scale_factor,
|
241 |
+
)
|
242 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
243 |
+
raise ValueError(
|
244 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
245 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
246 |
+
)
|
247 |
+
|
248 |
+
if latents is None:
|
249 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
250 |
+
else:
|
251 |
+
latents = latents.to(device)
|
252 |
+
|
253 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
254 |
+
latents = latents * self.scheduler.init_noise_sigma
|
255 |
+
return latents
|
256 |
+
|
257 |
+
@property
|
258 |
+
def guidance_scale(self):
|
259 |
+
return self._guidance_scale
|
260 |
+
|
261 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
262 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
263 |
+
# corresponds to doing no classifier free guidance.
|
264 |
+
@property
|
265 |
+
def do_classifier_free_guidance(self):
|
266 |
+
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
267 |
+
|
268 |
+
@property
|
269 |
+
def num_timesteps(self):
|
270 |
+
return self._num_timesteps
|
271 |
+
|
272 |
+
@torch.no_grad()
|
273 |
+
def __call__(
|
274 |
+
self,
|
275 |
+
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
276 |
+
pose_embedding: torch.FloatTensor,
|
277 |
+
height: int = 576,
|
278 |
+
width: int = 1024,
|
279 |
+
num_frames: Optional[int] = None,
|
280 |
+
num_inference_steps: int = 25,
|
281 |
+
min_guidance_scale: float = 1.0,
|
282 |
+
max_guidance_scale: float = 3.0,
|
283 |
+
fps: int = 7,
|
284 |
+
motion_bucket_id: int = 127,
|
285 |
+
noise_aug_strength: int = 0.02,
|
286 |
+
do_resize_normalize: bool = True,
|
287 |
+
do_image_process: bool = False,
|
288 |
+
decode_chunk_size: Optional[int] = None,
|
289 |
+
num_videos_per_prompt: Optional[int] = 1,
|
290 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
291 |
+
latents: Optional[torch.FloatTensor] = None,
|
292 |
+
output_type: Optional[str] = "pil",
|
293 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
294 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
295 |
+
return_dict: bool = True,
|
296 |
+
):
|
297 |
+
r"""
|
298 |
+
The call function to the pipeline for generation.
|
299 |
+
|
300 |
+
Args:
|
301 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
302 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
303 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
304 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
305 |
+
The height in pixels of the generated image.
|
306 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
307 |
+
The width in pixels of the generated image.
|
308 |
+
num_frames (`int`, *optional*):
|
309 |
+
The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
|
310 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
311 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
312 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
313 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
314 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
315 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
316 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
317 |
+
fps (`int`, *optional*, defaults to 7):
|
318 |
+
Frames per second. The rate at which the generated images shall be exported to a video after generation.
|
319 |
+
Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
320 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
321 |
+
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
|
322 |
+
noise_aug_strength (`int`, *optional*, defaults to 0.02):
|
323 |
+
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
|
324 |
+
decode_chunk_size (`int`, *optional*):
|
325 |
+
The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
|
326 |
+
between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
|
327 |
+
for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
|
328 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
329 |
+
The number of images to generate per prompt.
|
330 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
331 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
332 |
+
generation deterministic.
|
333 |
+
latents (`torch.FloatTensor`, *optional*):
|
334 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
335 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
336 |
+
tensor is generated by sampling using the supplied random `generator`.
|
337 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
338 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
339 |
+
callback_on_step_end (`Callable`, *optional*):
|
340 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
341 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
342 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
343 |
+
`callback_on_step_end_tensor_inputs`.
|
344 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
345 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
346 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
347 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
348 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
349 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
350 |
+
plain tuple.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
354 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
|
355 |
+
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
|
356 |
+
|
357 |
+
Examples:
|
358 |
+
|
359 |
+
```py
|
360 |
+
from diffusers import StableVideoDiffusionPipeline
|
361 |
+
from diffusers.utils import load_image, export_to_video
|
362 |
+
|
363 |
+
pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
|
364 |
+
pipe.to("cuda")
|
365 |
+
|
366 |
+
image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
|
367 |
+
image = image.resize((1024, 576))
|
368 |
+
|
369 |
+
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
|
370 |
+
export_to_video(frames, "generated.mp4", fps=7)
|
371 |
+
```
|
372 |
+
"""
|
373 |
+
# 0. Default height and width to unet
|
374 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
375 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
376 |
+
|
377 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
378 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
379 |
+
|
380 |
+
# 1. Check inputs. Raise error if not correct
|
381 |
+
self.check_inputs(image, height, width)
|
382 |
+
|
383 |
+
# 2. Define call parameters
|
384 |
+
if isinstance(image, PIL.Image.Image):
|
385 |
+
batch_size = 1
|
386 |
+
elif isinstance(image, list):
|
387 |
+
batch_size = len(image)
|
388 |
+
else:
|
389 |
+
batch_size = image.shape[0]
|
390 |
+
device = pose_embedding.device
|
391 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
392 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
393 |
+
# corresponds to doing no classifier free guidance.
|
394 |
+
do_classifier_free_guidance = max_guidance_scale > 1.0
|
395 |
+
|
396 |
+
# 3. Encode input image
|
397 |
+
image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance, do_resize_normalize=do_resize_normalize)
|
398 |
+
|
399 |
+
# NOTE: Stable Diffusion Video was conditioned on fps - 1, which
|
400 |
+
# is why it is reduced here.
|
401 |
+
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
|
402 |
+
fps = fps - 1
|
403 |
+
|
404 |
+
# 4. Encode input image using VAE
|
405 |
+
if do_image_process:
|
406 |
+
image = self.image_processor.preprocess(image, height=height, width=width).to(image_embeddings.device)
|
407 |
+
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
|
408 |
+
image = image + noise_aug_strength * noise
|
409 |
+
|
410 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
411 |
+
if needs_upcasting:
|
412 |
+
self.vae.to(dtype=torch.float32)
|
413 |
+
|
414 |
+
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
|
415 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
416 |
+
|
417 |
+
# cast back to fp16 if needed
|
418 |
+
if needs_upcasting:
|
419 |
+
self.vae.to(dtype=torch.float16)
|
420 |
+
|
421 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
422 |
+
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
|
423 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
424 |
+
|
425 |
+
# 5. Get Added Time IDs
|
426 |
+
added_time_ids = self._get_add_time_ids(
|
427 |
+
fps,
|
428 |
+
motion_bucket_id,
|
429 |
+
noise_aug_strength,
|
430 |
+
image_embeddings.dtype,
|
431 |
+
batch_size,
|
432 |
+
num_videos_per_prompt,
|
433 |
+
do_classifier_free_guidance,
|
434 |
+
)
|
435 |
+
added_time_ids = added_time_ids.to(device)
|
436 |
+
|
437 |
+
# 4. Prepare timesteps
|
438 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
439 |
+
timesteps = self.scheduler.timesteps
|
440 |
+
|
441 |
+
# 5. Prepare latent variables
|
442 |
+
num_channels_latents = self.unet.config.in_channels
|
443 |
+
latents = self.prepare_latents(
|
444 |
+
batch_size * num_videos_per_prompt,
|
445 |
+
num_frames,
|
446 |
+
num_channels_latents,
|
447 |
+
height,
|
448 |
+
width,
|
449 |
+
image_embeddings.dtype,
|
450 |
+
device,
|
451 |
+
generator,
|
452 |
+
latents,
|
453 |
+
) # [bs, frame, c, h, w]
|
454 |
+
|
455 |
+
# 7. Prepare guidance scale
|
456 |
+
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
457 |
+
guidance_scale = guidance_scale.to(device, latents.dtype)
|
458 |
+
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
459 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim) # [bs, frame, 1, 1, 1]
|
460 |
+
|
461 |
+
self._guidance_scale = guidance_scale
|
462 |
+
|
463 |
+
# 8. Prepare pose features
|
464 |
+
assert pose_embedding.ndim == 5 # [b, f, c, h, w]
|
465 |
+
pose_features = self.pose_encoder(pose_embedding) # list of [b, c, f, h, w]
|
466 |
+
pose_features = [torch.cat([x, x], dim=0) for x in pose_features] if do_classifier_free_guidance else pose_features
|
467 |
+
|
468 |
+
# 9. Denoising loop
|
469 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
470 |
+
self._num_timesteps = len(timesteps)
|
471 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
472 |
+
for i, t in enumerate(timesteps):
|
473 |
+
# expand the latents if we are doing classifier free guidance
|
474 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
475 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
476 |
+
|
477 |
+
# Concatenate image_latents over channels dimention
|
478 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
479 |
+
|
480 |
+
# predict the noise residual
|
481 |
+
noise_pred = self.unet(
|
482 |
+
latent_model_input,
|
483 |
+
t,
|
484 |
+
encoder_hidden_states=image_embeddings,
|
485 |
+
added_time_ids=added_time_ids,
|
486 |
+
pose_features=pose_features,
|
487 |
+
return_dict=False,
|
488 |
+
)[0]
|
489 |
+
|
490 |
+
# perform guidance
|
491 |
+
if do_classifier_free_guidance:
|
492 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
493 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
494 |
+
|
495 |
+
# compute the previous noisy sample x_t -> x_t-1
|
496 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
497 |
+
|
498 |
+
if callback_on_step_end is not None:
|
499 |
+
callback_kwargs = {}
|
500 |
+
for k in callback_on_step_end_tensor_inputs:
|
501 |
+
callback_kwargs[k] = locals()[k]
|
502 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
503 |
+
|
504 |
+
latents = callback_outputs.pop("latents", latents)
|
505 |
+
|
506 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
507 |
+
progress_bar.update()
|
508 |
+
|
509 |
+
if not output_type == "latent":
|
510 |
+
# cast back to fp16 if needed
|
511 |
+
if needs_upcasting:
|
512 |
+
self.vae.to(dtype=torch.float16)
|
513 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size) # [b, c, f, h, w]
|
514 |
+
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
515 |
+
else:
|
516 |
+
frames = latents
|
517 |
+
|
518 |
+
self.maybe_free_model_hooks()
|
519 |
+
|
520 |
+
if not return_dict:
|
521 |
+
return frames
|
522 |
+
|
523 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
cameractrl/utils/convert_from_ckpt.py
ADDED
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the Stable Diffusion checkpoints."""
|
16 |
+
|
17 |
+
import re
|
18 |
+
from transformers import CLIPTextModel
|
19 |
+
|
20 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
21 |
+
"""
|
22 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
23 |
+
"""
|
24 |
+
if n_shave_prefix_segments >= 0:
|
25 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
26 |
+
else:
|
27 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
28 |
+
|
29 |
+
|
30 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
31 |
+
"""
|
32 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
33 |
+
"""
|
34 |
+
mapping = []
|
35 |
+
for old_item in old_list:
|
36 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
37 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
38 |
+
|
39 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
40 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
41 |
+
|
42 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
43 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
44 |
+
|
45 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
46 |
+
|
47 |
+
mapping.append({"old": old_item, "new": new_item})
|
48 |
+
|
49 |
+
return mapping
|
50 |
+
|
51 |
+
|
52 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
53 |
+
"""
|
54 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
55 |
+
"""
|
56 |
+
mapping = []
|
57 |
+
for old_item in old_list:
|
58 |
+
new_item = old_item
|
59 |
+
|
60 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
61 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
62 |
+
|
63 |
+
mapping.append({"old": old_item, "new": new_item})
|
64 |
+
|
65 |
+
return mapping
|
66 |
+
|
67 |
+
|
68 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
69 |
+
"""
|
70 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
71 |
+
"""
|
72 |
+
mapping = []
|
73 |
+
for old_item in old_list:
|
74 |
+
new_item = old_item
|
75 |
+
|
76 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
77 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
78 |
+
|
79 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
80 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
81 |
+
|
82 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
83 |
+
|
84 |
+
mapping.append({"old": old_item, "new": new_item})
|
85 |
+
|
86 |
+
return mapping
|
87 |
+
|
88 |
+
|
89 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
90 |
+
"""
|
91 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
92 |
+
"""
|
93 |
+
mapping = []
|
94 |
+
for old_item in old_list:
|
95 |
+
new_item = old_item
|
96 |
+
|
97 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
98 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
99 |
+
|
100 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
101 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
102 |
+
|
103 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
104 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
105 |
+
|
106 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
107 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
108 |
+
|
109 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
110 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
111 |
+
|
112 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
113 |
+
|
114 |
+
mapping.append({"old": old_item, "new": new_item})
|
115 |
+
|
116 |
+
return mapping
|
117 |
+
|
118 |
+
|
119 |
+
def assign_to_checkpoint(
|
120 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
121 |
+
):
|
122 |
+
"""
|
123 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
124 |
+
attention layers, and takes into account additional replacements that may arise.
|
125 |
+
|
126 |
+
Assigns the weights to the new checkpoint.
|
127 |
+
"""
|
128 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
129 |
+
|
130 |
+
# Splits the attention layers into three variables.
|
131 |
+
if attention_paths_to_split is not None:
|
132 |
+
for path, path_map in attention_paths_to_split.items():
|
133 |
+
old_tensor = old_checkpoint[path]
|
134 |
+
channels = old_tensor.shape[0] // 3
|
135 |
+
|
136 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
137 |
+
|
138 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
139 |
+
|
140 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
141 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
142 |
+
|
143 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
144 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
145 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
146 |
+
|
147 |
+
for path in paths:
|
148 |
+
new_path = path["new"]
|
149 |
+
|
150 |
+
# These have already been assigned
|
151 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
152 |
+
continue
|
153 |
+
|
154 |
+
# Global renaming happens here
|
155 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
156 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
157 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
158 |
+
|
159 |
+
if additional_replacements is not None:
|
160 |
+
for replacement in additional_replacements:
|
161 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
162 |
+
|
163 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
164 |
+
if "proj_attn.weight" in new_path:
|
165 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
166 |
+
else:
|
167 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
168 |
+
|
169 |
+
|
170 |
+
def conv_attn_to_linear(checkpoint):
|
171 |
+
keys = list(checkpoint.keys())
|
172 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
173 |
+
for key in keys:
|
174 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
175 |
+
if checkpoint[key].ndim > 2:
|
176 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
177 |
+
elif "proj_attn.weight" in key:
|
178 |
+
if checkpoint[key].ndim > 2:
|
179 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
180 |
+
|
181 |
+
|
182 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
183 |
+
"""
|
184 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
185 |
+
"""
|
186 |
+
|
187 |
+
# extract state_dict for UNet
|
188 |
+
unet_state_dict = {}
|
189 |
+
keys = list(checkpoint.keys())
|
190 |
+
|
191 |
+
if controlnet:
|
192 |
+
unet_key = "control_model."
|
193 |
+
else:
|
194 |
+
unet_key = "model.diffusion_model."
|
195 |
+
|
196 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
197 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
198 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
199 |
+
print(
|
200 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
201 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
202 |
+
)
|
203 |
+
for key in keys:
|
204 |
+
if key.startswith("model.diffusion_model"):
|
205 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
206 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
207 |
+
else:
|
208 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
209 |
+
print(
|
210 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
211 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
212 |
+
)
|
213 |
+
|
214 |
+
for key in keys:
|
215 |
+
if key.startswith(unet_key):
|
216 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
217 |
+
|
218 |
+
new_checkpoint = {}
|
219 |
+
|
220 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
221 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
222 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
223 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
224 |
+
|
225 |
+
if config["class_embed_type"] is None:
|
226 |
+
# No parameters to port
|
227 |
+
...
|
228 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
229 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
230 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
231 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
232 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
233 |
+
else:
|
234 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
235 |
+
|
236 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
237 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
238 |
+
|
239 |
+
if not controlnet:
|
240 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
241 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
242 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
243 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
244 |
+
|
245 |
+
# Retrieves the keys for the input blocks only
|
246 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
247 |
+
input_blocks = {
|
248 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
249 |
+
for layer_id in range(num_input_blocks)
|
250 |
+
}
|
251 |
+
|
252 |
+
# Retrieves the keys for the middle blocks only
|
253 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
254 |
+
middle_blocks = {
|
255 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
256 |
+
for layer_id in range(num_middle_blocks)
|
257 |
+
}
|
258 |
+
|
259 |
+
# Retrieves the keys for the output blocks only
|
260 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
261 |
+
output_blocks = {
|
262 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
263 |
+
for layer_id in range(num_output_blocks)
|
264 |
+
}
|
265 |
+
|
266 |
+
for i in range(1, num_input_blocks):
|
267 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
268 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
269 |
+
|
270 |
+
resnets = [
|
271 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
272 |
+
]
|
273 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
274 |
+
|
275 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
276 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
277 |
+
f"input_blocks.{i}.0.op.weight"
|
278 |
+
)
|
279 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
280 |
+
f"input_blocks.{i}.0.op.bias"
|
281 |
+
)
|
282 |
+
|
283 |
+
paths = renew_resnet_paths(resnets)
|
284 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
285 |
+
assign_to_checkpoint(
|
286 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
287 |
+
)
|
288 |
+
|
289 |
+
if len(attentions):
|
290 |
+
paths = renew_attention_paths(attentions)
|
291 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
292 |
+
assign_to_checkpoint(
|
293 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
294 |
+
)
|
295 |
+
|
296 |
+
resnet_0 = middle_blocks[0]
|
297 |
+
attentions = middle_blocks[1]
|
298 |
+
resnet_1 = middle_blocks[2]
|
299 |
+
|
300 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
301 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
302 |
+
|
303 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
304 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
305 |
+
|
306 |
+
attentions_paths = renew_attention_paths(attentions)
|
307 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
308 |
+
assign_to_checkpoint(
|
309 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
310 |
+
)
|
311 |
+
|
312 |
+
for i in range(num_output_blocks):
|
313 |
+
block_id = i // (config["layers_per_block"] + 1)
|
314 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
315 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
316 |
+
output_block_list = {}
|
317 |
+
|
318 |
+
for layer in output_block_layers:
|
319 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
320 |
+
if layer_id in output_block_list:
|
321 |
+
output_block_list[layer_id].append(layer_name)
|
322 |
+
else:
|
323 |
+
output_block_list[layer_id] = [layer_name]
|
324 |
+
|
325 |
+
if len(output_block_list) > 1:
|
326 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
327 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
328 |
+
|
329 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
330 |
+
paths = renew_resnet_paths(resnets)
|
331 |
+
|
332 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
333 |
+
assign_to_checkpoint(
|
334 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
335 |
+
)
|
336 |
+
|
337 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
338 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
339 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
340 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
341 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
342 |
+
]
|
343 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
344 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
345 |
+
]
|
346 |
+
|
347 |
+
# Clear attentions as they have been attributed above.
|
348 |
+
if len(attentions) == 2:
|
349 |
+
attentions = []
|
350 |
+
|
351 |
+
if len(attentions):
|
352 |
+
paths = renew_attention_paths(attentions)
|
353 |
+
meta_path = {
|
354 |
+
"old": f"output_blocks.{i}.1",
|
355 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
356 |
+
}
|
357 |
+
assign_to_checkpoint(
|
358 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
359 |
+
)
|
360 |
+
else:
|
361 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
362 |
+
for path in resnet_0_paths:
|
363 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
364 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
365 |
+
|
366 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
367 |
+
|
368 |
+
if controlnet:
|
369 |
+
# conditioning embedding
|
370 |
+
|
371 |
+
orig_index = 0
|
372 |
+
|
373 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
374 |
+
f"input_hint_block.{orig_index}.weight"
|
375 |
+
)
|
376 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
377 |
+
f"input_hint_block.{orig_index}.bias"
|
378 |
+
)
|
379 |
+
|
380 |
+
orig_index += 2
|
381 |
+
|
382 |
+
diffusers_index = 0
|
383 |
+
|
384 |
+
while diffusers_index < 6:
|
385 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
386 |
+
f"input_hint_block.{orig_index}.weight"
|
387 |
+
)
|
388 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
389 |
+
f"input_hint_block.{orig_index}.bias"
|
390 |
+
)
|
391 |
+
diffusers_index += 1
|
392 |
+
orig_index += 2
|
393 |
+
|
394 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
395 |
+
f"input_hint_block.{orig_index}.weight"
|
396 |
+
)
|
397 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
398 |
+
f"input_hint_block.{orig_index}.bias"
|
399 |
+
)
|
400 |
+
|
401 |
+
# down blocks
|
402 |
+
for i in range(num_input_blocks):
|
403 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
404 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
405 |
+
|
406 |
+
# mid block
|
407 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
408 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
409 |
+
|
410 |
+
return new_checkpoint
|
411 |
+
|
412 |
+
|
413 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
414 |
+
# extract state dict for VAE
|
415 |
+
vae_state_dict = {}
|
416 |
+
keys = list(checkpoint.keys())
|
417 |
+
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
|
418 |
+
for key in keys:
|
419 |
+
if key.startswith(vae_key):
|
420 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
421 |
+
|
422 |
+
new_checkpoint = {}
|
423 |
+
|
424 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
425 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
426 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
427 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
428 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
429 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
430 |
+
|
431 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
432 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
433 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
434 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
435 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
436 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
437 |
+
|
438 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
439 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
440 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
441 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
442 |
+
|
443 |
+
# Retrieves the keys for the encoder down blocks only
|
444 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
445 |
+
down_blocks = {
|
446 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
447 |
+
}
|
448 |
+
|
449 |
+
# Retrieves the keys for the decoder up blocks only
|
450 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
451 |
+
up_blocks = {
|
452 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
453 |
+
}
|
454 |
+
|
455 |
+
for i in range(num_down_blocks):
|
456 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
457 |
+
|
458 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
459 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
460 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
461 |
+
)
|
462 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
463 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
464 |
+
)
|
465 |
+
|
466 |
+
paths = renew_vae_resnet_paths(resnets)
|
467 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
468 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
469 |
+
|
470 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
471 |
+
num_mid_res_blocks = 2
|
472 |
+
for i in range(1, num_mid_res_blocks + 1):
|
473 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
474 |
+
|
475 |
+
paths = renew_vae_resnet_paths(resnets)
|
476 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
477 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
478 |
+
|
479 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
480 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
481 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
482 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
483 |
+
conv_attn_to_linear(new_checkpoint)
|
484 |
+
|
485 |
+
for i in range(num_up_blocks):
|
486 |
+
block_id = num_up_blocks - 1 - i
|
487 |
+
resnets = [
|
488 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
489 |
+
]
|
490 |
+
|
491 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
492 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
493 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
494 |
+
]
|
495 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
496 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
497 |
+
]
|
498 |
+
|
499 |
+
paths = renew_vae_resnet_paths(resnets)
|
500 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
501 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
502 |
+
|
503 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
504 |
+
num_mid_res_blocks = 2
|
505 |
+
for i in range(1, num_mid_res_blocks + 1):
|
506 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
507 |
+
|
508 |
+
paths = renew_vae_resnet_paths(resnets)
|
509 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
510 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
511 |
+
|
512 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
513 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
514 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
515 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
516 |
+
conv_attn_to_linear(new_checkpoint)
|
517 |
+
return new_checkpoint
|
518 |
+
|
519 |
+
|
520 |
+
def convert_ldm_clip_checkpoint(checkpoint):
|
521 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
522 |
+
keys = list(checkpoint.keys())
|
523 |
+
|
524 |
+
text_model_dict = {}
|
525 |
+
|
526 |
+
for key in keys:
|
527 |
+
if key.startswith("cond_stage_model.transformer"):
|
528 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
529 |
+
|
530 |
+
text_model.load_state_dict(text_model_dict)
|
531 |
+
|
532 |
+
return text_model
|
533 |
+
|
534 |
+
|
535 |
+
textenc_conversion_lst = [
|
536 |
+
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
537 |
+
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
538 |
+
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
539 |
+
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
540 |
+
]
|
541 |
+
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
542 |
+
|
543 |
+
textenc_transformer_conversion_lst = [
|
544 |
+
# (stable-diffusion, HF Diffusers)
|
545 |
+
("resblocks.", "text_model.encoder.layers."),
|
546 |
+
("ln_1", "layer_norm1"),
|
547 |
+
("ln_2", "layer_norm2"),
|
548 |
+
(".c_fc.", ".fc1."),
|
549 |
+
(".c_proj.", ".fc2."),
|
550 |
+
(".attn", ".self_attn"),
|
551 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
552 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
553 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
554 |
+
]
|
555 |
+
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
556 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
cameractrl/utils/convert_lora_safetensor_to_diffusers.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Conversion script for the LoRA's safetensors checkpoints. """
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from diffusers import StableDiffusionPipeline
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
|
29 |
+
# directly update weight in diffusers model
|
30 |
+
for key in state_dict:
|
31 |
+
# only process lora down key
|
32 |
+
if "up." in key: continue
|
33 |
+
|
34 |
+
up_key = key.replace(".down.", ".up.")
|
35 |
+
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
|
36 |
+
model_key = model_key.replace("to_out.", "to_out.0.")
|
37 |
+
layer_infos = model_key.split(".")[:-1]
|
38 |
+
|
39 |
+
curr_layer = pipeline.unet
|
40 |
+
while len(layer_infos) > 0:
|
41 |
+
temp_name = layer_infos.pop(0)
|
42 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
43 |
+
|
44 |
+
weight_down = state_dict[key]
|
45 |
+
weight_up = state_dict[up_key]
|
46 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
47 |
+
|
48 |
+
return pipeline
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
53 |
+
# load base model
|
54 |
+
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
55 |
+
|
56 |
+
# load LoRA weight from .safetensors
|
57 |
+
# state_dict = load_file(checkpoint_path)
|
58 |
+
|
59 |
+
visited = []
|
60 |
+
|
61 |
+
# directly update weight in diffusers model
|
62 |
+
for key in state_dict:
|
63 |
+
# it is suggested to print out the key, it usually will be something like below
|
64 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
65 |
+
|
66 |
+
# as we have set the alpha beforehand, so just skip
|
67 |
+
if ".alpha" in key or key in visited:
|
68 |
+
continue
|
69 |
+
|
70 |
+
if "text" in key:
|
71 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
72 |
+
curr_layer = pipeline.text_encoder
|
73 |
+
else:
|
74 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
75 |
+
curr_layer = pipeline.unet
|
76 |
+
|
77 |
+
# find the target layer
|
78 |
+
temp_name = layer_infos.pop(0)
|
79 |
+
while len(layer_infos) > -1:
|
80 |
+
try:
|
81 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
82 |
+
if len(layer_infos) > 0:
|
83 |
+
temp_name = layer_infos.pop(0)
|
84 |
+
elif len(layer_infos) == 0:
|
85 |
+
break
|
86 |
+
except Exception:
|
87 |
+
if len(temp_name) > 0:
|
88 |
+
temp_name += "_" + layer_infos.pop(0)
|
89 |
+
else:
|
90 |
+
temp_name = layer_infos.pop(0)
|
91 |
+
|
92 |
+
pair_keys = []
|
93 |
+
if "lora_down" in key:
|
94 |
+
pair_keys.append(key.replace("lora_down", "lora_up"))
|
95 |
+
pair_keys.append(key)
|
96 |
+
else:
|
97 |
+
pair_keys.append(key)
|
98 |
+
pair_keys.append(key.replace("lora_up", "lora_down"))
|
99 |
+
|
100 |
+
# update weight
|
101 |
+
if len(state_dict[pair_keys[0]].shape) == 4:
|
102 |
+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
103 |
+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
104 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
|
105 |
+
else:
|
106 |
+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
107 |
+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
108 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
109 |
+
|
110 |
+
# update visited list
|
111 |
+
for item in pair_keys:
|
112 |
+
visited.append(item)
|
113 |
+
|
114 |
+
return pipeline
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
parser = argparse.ArgumentParser()
|
119 |
+
|
120 |
+
parser.add_argument(
|
121 |
+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
125 |
+
)
|
126 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
127 |
+
parser.add_argument(
|
128 |
+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--lora_prefix_text_encoder",
|
132 |
+
default="lora_te",
|
133 |
+
type=str,
|
134 |
+
help="The prefix of text encoder weight in safetensors",
|
135 |
+
)
|
136 |
+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
137 |
+
parser.add_argument(
|
138 |
+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
139 |
+
)
|
140 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
141 |
+
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
base_model_path = args.base_model_path
|
145 |
+
checkpoint_path = args.checkpoint_path
|
146 |
+
dump_path = args.dump_path
|
147 |
+
lora_prefix_unet = args.lora_prefix_unet
|
148 |
+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
149 |
+
alpha = args.alpha
|
150 |
+
|
151 |
+
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
|
152 |
+
|
153 |
+
pipe = pipe.to(args.device)
|
154 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
cameractrl/utils/util.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import functools
|
3 |
+
import logging
|
4 |
+
import sys
|
5 |
+
import imageio
|
6 |
+
import atexit
|
7 |
+
import importlib
|
8 |
+
import torch
|
9 |
+
import torchvision
|
10 |
+
import numpy as np
|
11 |
+
from termcolor import colored
|
12 |
+
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
|
16 |
+
def instantiate_from_config(config, **additional_kwargs):
|
17 |
+
if not "target" in config:
|
18 |
+
if config == '__is_first_stage__':
|
19 |
+
return None
|
20 |
+
elif config == "__is_unconditional__":
|
21 |
+
return None
|
22 |
+
raise KeyError("Expected key `target` to instantiate.")
|
23 |
+
|
24 |
+
additional_kwargs.update(config.get("kwargs", dict()))
|
25 |
+
return get_obj_from_str(config["target"])(**additional_kwargs)
|
26 |
+
|
27 |
+
|
28 |
+
def get_obj_from_str(string, reload=False):
|
29 |
+
module, cls = string.rsplit(".", 1)
|
30 |
+
if reload:
|
31 |
+
module_imp = importlib.import_module(module)
|
32 |
+
importlib.reload(module_imp)
|
33 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
34 |
+
|
35 |
+
|
36 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
37 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
38 |
+
outputs = []
|
39 |
+
for x in videos:
|
40 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
41 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
42 |
+
if rescale:
|
43 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
44 |
+
x = (x * 255).numpy().astype(np.uint8)
|
45 |
+
outputs.append(x)
|
46 |
+
|
47 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
48 |
+
imageio.mimsave(path, outputs, fps=fps)
|
49 |
+
|
50 |
+
|
51 |
+
# Logger utils are copied from detectron2
|
52 |
+
class _ColorfulFormatter(logging.Formatter):
|
53 |
+
def __init__(self, *args, **kwargs):
|
54 |
+
self._root_name = kwargs.pop("root_name") + "."
|
55 |
+
self._abbrev_name = kwargs.pop("abbrev_name", "")
|
56 |
+
if len(self._abbrev_name):
|
57 |
+
self._abbrev_name = self._abbrev_name + "."
|
58 |
+
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
|
59 |
+
|
60 |
+
def formatMessage(self, record):
|
61 |
+
record.name = record.name.replace(self._root_name, self._abbrev_name)
|
62 |
+
log = super(_ColorfulFormatter, self).formatMessage(record)
|
63 |
+
if record.levelno == logging.WARNING:
|
64 |
+
prefix = colored("WARNING", "red", attrs=["blink"])
|
65 |
+
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
|
66 |
+
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
|
67 |
+
else:
|
68 |
+
return log
|
69 |
+
return prefix + " " + log
|
70 |
+
|
71 |
+
|
72 |
+
# cache the opened file object, so that different calls to `setup_logger`
|
73 |
+
# with the same file name can safely write to the same file.
|
74 |
+
@functools.lru_cache(maxsize=None)
|
75 |
+
def _cached_log_stream(filename):
|
76 |
+
# use 1K buffer if writing to cloud storage
|
77 |
+
io = open(filename, "a", buffering=1024 if "://" in filename else -1)
|
78 |
+
atexit.register(io.close)
|
79 |
+
return io
|
80 |
+
|
81 |
+
@functools.lru_cache()
|
82 |
+
def setup_logger(output, distributed_rank, color=True, name='AnimateDiff', abbrev_name=None):
|
83 |
+
logger = logging.getLogger(name)
|
84 |
+
logger.setLevel(logging.DEBUG)
|
85 |
+
logger.propagate = False
|
86 |
+
|
87 |
+
if abbrev_name is None:
|
88 |
+
abbrev_name = 'AD'
|
89 |
+
plain_formatter = logging.Formatter(
|
90 |
+
"[%(asctime)s] %(name)s:%(lineno)d %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
|
91 |
+
)
|
92 |
+
|
93 |
+
# stdout logging: master only
|
94 |
+
if distributed_rank == 0:
|
95 |
+
ch = logging.StreamHandler(stream=sys.stdout)
|
96 |
+
ch.setLevel(logging.DEBUG)
|
97 |
+
if color:
|
98 |
+
formatter = _ColorfulFormatter(
|
99 |
+
colored("[%(asctime)s %(name)s:%(lineno)d]: ", "green") + "%(message)s",
|
100 |
+
datefmt="%m/%d %H:%M:%S",
|
101 |
+
root_name=name,
|
102 |
+
abbrev_name=str(abbrev_name),
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
formatter = plain_formatter
|
106 |
+
ch.setFormatter(formatter)
|
107 |
+
logger.addHandler(ch)
|
108 |
+
|
109 |
+
# file logging: all workers
|
110 |
+
if output is not None:
|
111 |
+
if output.endswith(".txt") or output.endswith(".log"):
|
112 |
+
filename = output
|
113 |
+
else:
|
114 |
+
filename = os.path.join(output, "log.txt")
|
115 |
+
if distributed_rank > 0:
|
116 |
+
filename = filename + ".rank{}".format(distributed_rank)
|
117 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
118 |
+
|
119 |
+
fh = logging.StreamHandler(_cached_log_stream(filename))
|
120 |
+
fh.setLevel(logging.DEBUG)
|
121 |
+
fh.setFormatter(plain_formatter)
|
122 |
+
logger.addHandler(fh)
|
123 |
+
|
124 |
+
return logger
|
125 |
+
|
126 |
+
|
127 |
+
def format_time(elapsed_time):
|
128 |
+
# Time thresholds
|
129 |
+
minute = 60
|
130 |
+
hour = 60 * minute
|
131 |
+
day = 24 * hour
|
132 |
+
|
133 |
+
days, remainder = divmod(elapsed_time, day)
|
134 |
+
hours, remainder = divmod(remainder, hour)
|
135 |
+
minutes, seconds = divmod(remainder, minute)
|
136 |
+
|
137 |
+
formatted_time = ""
|
138 |
+
|
139 |
+
if days > 0:
|
140 |
+
formatted_time += f"{int(days)} days "
|
141 |
+
if hours > 0:
|
142 |
+
formatted_time += f"{int(hours)} hours "
|
143 |
+
if minutes > 0:
|
144 |
+
formatted_time += f"{int(minutes)} minutes "
|
145 |
+
if seconds > 0:
|
146 |
+
formatted_time += f"{seconds:.2f} seconds"
|
147 |
+
|
148 |
+
return formatted_time.strip()
|
configs/train_cameractrl/svd_320_576_cameractrl.yaml
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_dir: "output/cameractrl_model"
|
2 |
+
pretrained_model_path: "/mnt/petrelfs/liangzhengyang.d/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid/snapshots/2586584918a955489b599d4dc76b6bb3fdb3fbb2"
|
3 |
+
unet_subfolder: "unet"
|
4 |
+
down_block_types: ['CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'DownBlockSpatioTemporal']
|
5 |
+
up_block_types: ['UpBlockSpatioTemporal', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond']
|
6 |
+
|
7 |
+
train_data:
|
8 |
+
root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
|
9 |
+
annotation_json: "annotations/train.json"
|
10 |
+
sample_stride: 8
|
11 |
+
sample_n_frames: 14
|
12 |
+
relative_pose: true
|
13 |
+
zero_t_first_frame: true
|
14 |
+
sample_size: [320, 576]
|
15 |
+
rescale_fxy: true
|
16 |
+
shuffle_frames: false
|
17 |
+
use_flip: false
|
18 |
+
|
19 |
+
validation_data:
|
20 |
+
root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
|
21 |
+
annotation_json: "annotations/validation.json"
|
22 |
+
sample_stride: 8
|
23 |
+
sample_n_frames: 14
|
24 |
+
relative_pose: true
|
25 |
+
zero_t_first_frame: true
|
26 |
+
sample_size: [320, 576]
|
27 |
+
rescale_fxy: true
|
28 |
+
shuffle_frames: false
|
29 |
+
use_flip: false
|
30 |
+
return_clip_name: true
|
31 |
+
|
32 |
+
random_null_image_ratio: 0.15
|
33 |
+
|
34 |
+
pose_encoder_kwargs:
|
35 |
+
downscale_factor: 8
|
36 |
+
channels: [320, 640, 1280, 1280]
|
37 |
+
nums_rb: 2
|
38 |
+
cin: 384
|
39 |
+
ksize: 1
|
40 |
+
sk: true
|
41 |
+
use_conv: false
|
42 |
+
compression_factor: 1
|
43 |
+
temporal_attention_nhead: 8
|
44 |
+
attention_block_types: ["Temporal_Self", ]
|
45 |
+
temporal_position_encoding: true
|
46 |
+
temporal_position_encoding_max_len: 14
|
47 |
+
|
48 |
+
attention_processor_kwargs:
|
49 |
+
add_spatial: false
|
50 |
+
add_temporal: true
|
51 |
+
attn_processor_name: 'attn1'
|
52 |
+
pose_feature_dimensions: [320, 640, 1280, 1280]
|
53 |
+
query_condition: true
|
54 |
+
key_value_condition: true
|
55 |
+
scale: 1.0
|
56 |
+
|
57 |
+
do_sanity_check: true
|
58 |
+
sample_before_training: false
|
59 |
+
|
60 |
+
max_train_epoch: -1
|
61 |
+
max_train_steps: 50000
|
62 |
+
validation_steps: 2500
|
63 |
+
validation_steps_tuple: [500, ]
|
64 |
+
|
65 |
+
learning_rate: 3.e-5
|
66 |
+
|
67 |
+
P_mean: 0.7
|
68 |
+
P_std: 1.6
|
69 |
+
condition_image_noise_mean: -3.0
|
70 |
+
condition_image_noise_std: 0.5
|
71 |
+
sample_latent: true
|
72 |
+
first_image_cond: true
|
73 |
+
|
74 |
+
num_inference_steps: 25
|
75 |
+
min_guidance_scale: 1.0
|
76 |
+
max_guidance_scale: 3.0
|
77 |
+
|
78 |
+
num_workers: 8
|
79 |
+
train_batch_size: 1
|
80 |
+
checkpointing_epochs: -1
|
81 |
+
checkpointing_steps: 10000
|
82 |
+
|
83 |
+
mixed_precision_training: false
|
84 |
+
enable_xformers_memory_efficient_attention: true
|
85 |
+
|
86 |
+
global_seed: 42
|
87 |
+
logger_interval: 10
|
configs/train_cameractrl/svdxt_320_576_cameractrl.yaml
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_dir: "output/cameractrl_model"
|
2 |
+
pretrained_model_path: "/mnt/petrelfs/liangzhengyang.d/.cache/huggingface/hub/models--stabilityai--stable-video-diffusion-img2vid-xt/snapshots/4420c0886aad9930787308c62d9dd8befd4900f6"
|
3 |
+
unet_subfolder: "unet"
|
4 |
+
down_block_types: ['CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'CrossAttnDownBlockSpatioTemporalPoseCond', 'DownBlockSpatioTemporal']
|
5 |
+
up_block_types: ['UpBlockSpatioTemporal', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond', 'CrossAttnUpBlockSpatioTemporalPoseCond']
|
6 |
+
|
7 |
+
train_data:
|
8 |
+
root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
|
9 |
+
annotation_json: "annotations/train.json"
|
10 |
+
sample_stride: 5
|
11 |
+
sample_n_frames: 25
|
12 |
+
relative_pose: true
|
13 |
+
zero_t_first_frame: true
|
14 |
+
sample_size: [320, 576]
|
15 |
+
rescale_fxy: true
|
16 |
+
shuffle_frames: false
|
17 |
+
use_flip: false
|
18 |
+
|
19 |
+
validation_data:
|
20 |
+
root_path: "/mnt/petrelfs/share_data/hehao/datasets/RealEstate10k"
|
21 |
+
annotation_json: "annotations/validation.json"
|
22 |
+
sample_stride: 5
|
23 |
+
sample_n_frames: 25
|
24 |
+
relative_pose: true
|
25 |
+
zero_t_first_frame: true
|
26 |
+
sample_size: [320, 576]
|
27 |
+
rescale_fxy: true
|
28 |
+
shuffle_frames: false
|
29 |
+
use_flip: false
|
30 |
+
return_clip_name: true
|
31 |
+
|
32 |
+
random_null_image_ratio: 0.15
|
33 |
+
|
34 |
+
pose_encoder_kwargs:
|
35 |
+
downscale_factor: 8
|
36 |
+
channels: [320, 640, 1280, 1280]
|
37 |
+
nums_rb: 2
|
38 |
+
cin: 384
|
39 |
+
ksize: 1
|
40 |
+
sk: true
|
41 |
+
use_conv: false
|
42 |
+
compression_factor: 1
|
43 |
+
temporal_attention_nhead: 8
|
44 |
+
attention_block_types: ["Temporal_Self", ]
|
45 |
+
temporal_position_encoding: true
|
46 |
+
temporal_position_encoding_max_len: 25
|
47 |
+
|
48 |
+
attention_processor_kwargs:
|
49 |
+
add_spatial: false
|
50 |
+
add_temporal: true
|
51 |
+
attn_processor_name: 'attn1'
|
52 |
+
pose_feature_dimensions: [320, 640, 1280, 1280]
|
53 |
+
query_condition: true
|
54 |
+
key_value_condition: true
|
55 |
+
scale: 1.0
|
56 |
+
|
57 |
+
do_sanity_check: false
|
58 |
+
sample_before_training: false
|
59 |
+
video_length: 25
|
60 |
+
|
61 |
+
max_train_epoch: -1
|
62 |
+
max_train_steps: 50000
|
63 |
+
validation_steps: 2500
|
64 |
+
validation_steps_tuple: [1000, ]
|
65 |
+
|
66 |
+
learning_rate: 3.e-5
|
67 |
+
|
68 |
+
P_mean: 0.7
|
69 |
+
P_std: 1.6
|
70 |
+
condition_image_noise_mean: -3.0
|
71 |
+
condition_image_noise_std: 0.5
|
72 |
+
sample_latent: true
|
73 |
+
first_image_cond: true
|
74 |
+
|
75 |
+
num_inference_steps: 25
|
76 |
+
min_guidance_scale: 1.0
|
77 |
+
max_guidance_scale: 3.0
|
78 |
+
|
79 |
+
num_workers: 8
|
80 |
+
train_batch_size: 1
|
81 |
+
checkpointing_epochs: -1
|
82 |
+
checkpointing_steps: 10000
|
83 |
+
|
84 |
+
mixed_precision_training: false
|
85 |
+
enable_xformers_memory_efficient_attention: true
|
86 |
+
|
87 |
+
global_seed: 42
|
88 |
+
logger_interval: 10
|
inference_cameractrl.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from PIL import Image
|
9 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
10 |
+
from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
|
11 |
+
from diffusers.utils.import_utils import is_xformers_available
|
12 |
+
from diffusers.models.attention_processor import AttnProcessor2_0
|
13 |
+
from packaging import version as pver
|
14 |
+
|
15 |
+
from cameractrl.pipelines.pipeline_animation import StableVideoDiffusionPipelinePoseCond
|
16 |
+
from cameractrl.models.unet import UNetSpatioTemporalConditionModelPoseCond
|
17 |
+
from cameractrl.models.pose_adaptor import CameraPoseEncoder
|
18 |
+
from cameractrl.utils.util import save_videos_grid
|
19 |
+
|
20 |
+
|
21 |
+
class Camera(object):
|
22 |
+
def __init__(self, entry):
|
23 |
+
fx, fy, cx, cy = entry[1:5]
|
24 |
+
self.fx = fx
|
25 |
+
self.fy = fy
|
26 |
+
self.cx = cx
|
27 |
+
self.cy = cy
|
28 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
29 |
+
w2c_mat_4x4 = np.eye(4)
|
30 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
31 |
+
self.w2c_mat = w2c_mat_4x4
|
32 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
33 |
+
|
34 |
+
|
35 |
+
def setup_for_distributed(is_master):
|
36 |
+
"""
|
37 |
+
This function disables printing when not in master process
|
38 |
+
"""
|
39 |
+
import builtins as __builtin__
|
40 |
+
builtin_print = __builtin__.print
|
41 |
+
|
42 |
+
def print(*args, **kwargs):
|
43 |
+
force = kwargs.pop('force', False)
|
44 |
+
if is_master or force:
|
45 |
+
builtin_print(*args, **kwargs)
|
46 |
+
|
47 |
+
__builtin__.print = print
|
48 |
+
|
49 |
+
|
50 |
+
def custom_meshgrid(*args):
|
51 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
52 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
53 |
+
return torch.meshgrid(*args)
|
54 |
+
else:
|
55 |
+
return torch.meshgrid(*args, indexing='ij')
|
56 |
+
|
57 |
+
|
58 |
+
def get_relative_pose(cam_params, zero_first_frame_scale):
|
59 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
60 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
61 |
+
source_cam_c2w = abs_c2ws[0]
|
62 |
+
if zero_first_frame_scale:
|
63 |
+
cam_to_origin = 0
|
64 |
+
else:
|
65 |
+
cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
|
66 |
+
target_cam_c2w = np.array([
|
67 |
+
[1, 0, 0, 0],
|
68 |
+
[0, 1, 0, -cam_to_origin],
|
69 |
+
[0, 0, 1, 0],
|
70 |
+
[0, 0, 0, 1]
|
71 |
+
])
|
72 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
73 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
74 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
75 |
+
return ret_poses
|
76 |
+
|
77 |
+
|
78 |
+
def ray_condition(K, c2w, H, W, device):
|
79 |
+
# c2w: B, V, 4, 4
|
80 |
+
# K: B, V, 4
|
81 |
+
|
82 |
+
B = K.shape[0]
|
83 |
+
|
84 |
+
j, i = custom_meshgrid(
|
85 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
86 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
87 |
+
)
|
88 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
89 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
90 |
+
|
91 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
92 |
+
|
93 |
+
zs = torch.ones_like(i) # [B, HxW]
|
94 |
+
xs = (i - cx) / fx * zs
|
95 |
+
ys = (j - cy) / fy * zs
|
96 |
+
zs = zs.expand_as(ys)
|
97 |
+
|
98 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
99 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
100 |
+
|
101 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
102 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
103 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
104 |
+
# c2w @ dirctions
|
105 |
+
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
106 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
107 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
108 |
+
return plucker
|
109 |
+
|
110 |
+
|
111 |
+
def get_pipeline(ori_model_path, unet_subfolder, down_block_types, up_block_types, pose_encoder_kwargs,
|
112 |
+
attention_processor_kwargs, pose_adaptor_ckpt, enable_xformers, device):
|
113 |
+
noise_scheduler = EulerDiscreteScheduler.from_pretrained(ori_model_path, subfolder="scheduler")
|
114 |
+
feature_extractor = CLIPImageProcessor.from_pretrained(ori_model_path, subfolder="feature_extractor")
|
115 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(ori_model_path, subfolder="image_encoder")
|
116 |
+
vae = AutoencoderKLTemporalDecoder.from_pretrained(ori_model_path, subfolder="vae")
|
117 |
+
unet = UNetSpatioTemporalConditionModelPoseCond.from_pretrained(ori_model_path,
|
118 |
+
subfolder=unet_subfolder,
|
119 |
+
down_block_types=down_block_types,
|
120 |
+
up_block_types=up_block_types)
|
121 |
+
pose_encoder = CameraPoseEncoder(**pose_encoder_kwargs)
|
122 |
+
print("Setting the attention processors")
|
123 |
+
unet.set_pose_cond_attn_processor(enable_xformers=(enable_xformers and is_xformers_available()), **attention_processor_kwargs)
|
124 |
+
print(f"Loading weights of camera encoder and attention processor from {pose_adaptor_ckpt}")
|
125 |
+
ckpt_dict = torch.load(pose_adaptor_ckpt, map_location=unet.device)
|
126 |
+
pose_encoder_state_dict = ckpt_dict['pose_encoder_state_dict']
|
127 |
+
pose_encoder_m, pose_encoder_u = pose_encoder.load_state_dict(pose_encoder_state_dict)
|
128 |
+
assert len(pose_encoder_m) == 0 and len(pose_encoder_u) == 0
|
129 |
+
attention_processor_state_dict = ckpt_dict['attention_processor_state_dict']
|
130 |
+
_, attention_processor_u = unet.load_state_dict(attention_processor_state_dict, strict=False)
|
131 |
+
assert len(attention_processor_u) == 0
|
132 |
+
print("Loading done")
|
133 |
+
vae.set_attn_processor(AttnProcessor2_0())
|
134 |
+
vae.to(device)
|
135 |
+
image_encoder.to(device)
|
136 |
+
unet.to(device)
|
137 |
+
pipeline = StableVideoDiffusionPipelinePoseCond(
|
138 |
+
vae=vae,
|
139 |
+
image_encoder=image_encoder,
|
140 |
+
unet=unet,
|
141 |
+
scheduler=noise_scheduler,
|
142 |
+
feature_extractor=feature_extractor,
|
143 |
+
pose_encoder=pose_encoder
|
144 |
+
)
|
145 |
+
pipeline = pipeline.to(device)
|
146 |
+
return pipeline
|
147 |
+
|
148 |
+
|
149 |
+
def main(args):
|
150 |
+
os.makedirs(os.path.join(args.out_root, 'generated_videos'), exist_ok=True)
|
151 |
+
os.makedirs(os.path.join(args.out_root, 'reference_images'), exist_ok=True)
|
152 |
+
rank = args.local_rank
|
153 |
+
setup_for_distributed(rank == 0)
|
154 |
+
gpu_id = rank % torch.cuda.device_count()
|
155 |
+
model_configs = OmegaConf.load(args.model_config)
|
156 |
+
device = f"cuda:{gpu_id}"
|
157 |
+
print(f'Constructing pipeline')
|
158 |
+
pipeline = get_pipeline(args.ori_model_path, model_configs['unet_subfolder'], model_configs['down_block_types'],
|
159 |
+
model_configs['up_block_types'], model_configs['pose_encoder_kwargs'],
|
160 |
+
model_configs['attention_processor_kwargs'], args.pose_adaptor_ckpt, args.enable_xformers, device)
|
161 |
+
print('Done')
|
162 |
+
|
163 |
+
print('Loading K, R, t matrix')
|
164 |
+
with open(args.trajectory_file, 'r') as f:
|
165 |
+
poses = f.readlines()
|
166 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
167 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
168 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
169 |
+
|
170 |
+
sample_wh_ratio = args.image_width / args.image_height
|
171 |
+
pose_wh_ratio = args.original_pose_width / args.original_pose_height
|
172 |
+
if pose_wh_ratio > sample_wh_ratio:
|
173 |
+
resized_ori_w = args.image_height * pose_wh_ratio
|
174 |
+
for cam_param in cam_params:
|
175 |
+
cam_param.fx = resized_ori_w * cam_param.fx / args.image_width
|
176 |
+
else:
|
177 |
+
resized_ori_h = args.image_width / pose_wh_ratio
|
178 |
+
for cam_param in cam_params:
|
179 |
+
cam_param.fy = resized_ori_h * cam_param.fy / args.image_height
|
180 |
+
intrinsic = np.asarray([[cam_param.fx * args.image_width,
|
181 |
+
cam_param.fy * args.image_height,
|
182 |
+
cam_param.cx * args.image_width,
|
183 |
+
cam_param.cy * args.image_height]
|
184 |
+
for cam_param in cam_params], dtype=np.float32)
|
185 |
+
K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
|
186 |
+
c2ws = get_relative_pose(cam_params, zero_first_frame_scale=True)
|
187 |
+
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
|
188 |
+
plucker_embedding = ray_condition(K, c2ws, args.image_height, args.image_width, device='cpu') # b f h w 6
|
189 |
+
plucker_embedding = plucker_embedding.permute(0, 1, 4, 2, 3).contiguous().to(device=device)
|
190 |
+
|
191 |
+
prompt_dict = json.load(open(args.prompt_file, 'r'))
|
192 |
+
prompt_images = prompt_dict['image_paths']
|
193 |
+
prompt_captions = prompt_dict['captions']
|
194 |
+
N = int(len(prompt_images) // args.n_procs)
|
195 |
+
remainder = int(len(prompt_images) % args.n_procs)
|
196 |
+
prompts_per_gpu = [N + 1 if gpu_id < remainder else N for gpu_id in range(args.n_procs)]
|
197 |
+
low_idx = sum(prompts_per_gpu[:gpu_id])
|
198 |
+
high_idx = low_idx + prompts_per_gpu[gpu_id]
|
199 |
+
prompt_images = prompt_images[low_idx: high_idx]
|
200 |
+
prompt_captions = prompt_captions[low_idx: high_idx]
|
201 |
+
print(f"rank {rank} / {torch.cuda.device_count()}, number of prompts: {len(prompt_images)}")
|
202 |
+
|
203 |
+
generator = torch.Generator(device=device)
|
204 |
+
generator.manual_seed(42)
|
205 |
+
|
206 |
+
for prompt_image, prompt_caption in tqdm(zip(prompt_images, prompt_captions)):
|
207 |
+
save_name = "_".join(prompt_caption.split(" "))
|
208 |
+
condition_image = Image.open(prompt_image)
|
209 |
+
with torch.no_grad():
|
210 |
+
sample = pipeline(
|
211 |
+
image=condition_image,
|
212 |
+
pose_embedding=plucker_embedding,
|
213 |
+
height=args.image_height,
|
214 |
+
width=args.image_width,
|
215 |
+
num_frames=args.num_frames,
|
216 |
+
num_inference_steps=args.num_inference_steps,
|
217 |
+
min_guidance_scale=args.min_guidance_scale,
|
218 |
+
max_guidance_scale=args.max_guidance_scale,
|
219 |
+
do_image_process=True,
|
220 |
+
generator=generator,
|
221 |
+
output_type='pt'
|
222 |
+
).frames[0].transpose(0, 1).cpu() # [3, f, h, w] 0-1
|
223 |
+
resized_condition_image = condition_image.resize((args.image_width, args.image_height))
|
224 |
+
save_videos_grid(sample[None], f"{os.path.join(args.out_root, 'generated_videos')}/{save_name}.mp4", rescale=False)
|
225 |
+
resized_condition_image.save(os.path.join(args.out_root, 'reference_images', f'{save_name}.png'))
|
226 |
+
|
227 |
+
|
228 |
+
if __name__ == '__main__':
|
229 |
+
parser = argparse.ArgumentParser()
|
230 |
+
parser.add_argument("--out_root", type=str)
|
231 |
+
parser.add_argument("--image_height", type=int, default=320)
|
232 |
+
parser.add_argument("--image_width", type=int, default=576)
|
233 |
+
parser.add_argument("--num_frames", type=int, default=14)
|
234 |
+
parser.add_argument("--ori_model_path", type=str)
|
235 |
+
parser.add_argument("--unet_subfolder", type=str, default='unet')
|
236 |
+
parser.add_argument("--enable_xformers", action='store_true')
|
237 |
+
parser.add_argument("--pose_adaptor_ckpt", default=None)
|
238 |
+
parser.add_argument("--num_inference_steps", type=int, default=25)
|
239 |
+
parser.add_argument("--min_guidance_scale", type=float, default=1.0)
|
240 |
+
parser.add_argument("--max_guidance_scale", type=float, default=3.0)
|
241 |
+
parser.add_argument("--prompt_file", required=True, help='prompts path, json or txt')
|
242 |
+
parser.add_argument("--trajectory_file", required=True)
|
243 |
+
parser.add_argument("--original_pose_width", type=int, default=1280)
|
244 |
+
parser.add_argument("--original_pose_height", type=int, default=720)
|
245 |
+
parser.add_argument("--model_config", required=True)
|
246 |
+
parser.add_argument("--n_procs", type=int, default=8)
|
247 |
+
|
248 |
+
# DDP args
|
249 |
+
parser.add_argument("--world_size", default=1, type=int,
|
250 |
+
help="number of the distributed processes.")
|
251 |
+
parser.add_argument('--local-rank', type=int, default=-1,
|
252 |
+
help='Replica rank on the current node. This field is required '
|
253 |
+
'by `torch.distributed.launch`.')
|
254 |
+
args = parser.parse_args()
|
255 |
+
main(args)
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
diffusers==0.24.0
|
5 |
+
imageio==2.27.0
|
6 |
+
transformers==4.39.3
|
7 |
+
gradio==4.26.0
|
8 |
+
imageio==2.27.0
|
9 |
+
imageio-ffmpeg==0.4.9
|
10 |
+
accelerate==0.30.0
|
11 |
+
opencv-python
|
12 |
+
gdown
|
13 |
+
einops
|
14 |
+
decord
|
15 |
+
omegaconf
|
16 |
+
safetensors
|
17 |
+
gradio
|
18 |
+
wandb
|
19 |
+
triton
|
20 |
+
termcolor
|