Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
0f56e8b
1
Parent(s):
d794a86
update
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +201 -0
- app.py +284 -0
- demo.py +377 -0
- docs/config_help.md +27 -0
- extern/depthcrafter/__init__.py +0 -0
- extern/depthcrafter/__pycache__/__init__.cpython-310.pyc +0 -0
- extern/depthcrafter/__pycache__/demo.cpython-310.pyc +0 -0
- extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc +0 -0
- extern/depthcrafter/__pycache__/infer.cpython-310.pyc +0 -0
- extern/depthcrafter/__pycache__/unet.cpython-310.pyc +0 -0
- extern/depthcrafter/__pycache__/utils.cpython-310.pyc +0 -0
- extern/depthcrafter/depth_crafter_ppl.py +366 -0
- extern/depthcrafter/infer.py +91 -0
- extern/depthcrafter/unet.py +142 -0
- extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc +0 -0
- extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc +0 -0
- extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc +0 -0
- extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc +0 -0
- extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2.py +415 -0
- extern/video_depth_anything/dinov2_layers/__init__.py +11 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
- extern/video_depth_anything/dinov2_layers/attention.py +83 -0
- extern/video_depth_anything/dinov2_layers/block.py +252 -0
- extern/video_depth_anything/dinov2_layers/drop_path.py +35 -0
- extern/video_depth_anything/dinov2_layers/layer_scale.py +28 -0
- extern/video_depth_anything/dinov2_layers/mlp.py +41 -0
- extern/video_depth_anything/dinov2_layers/patch_embed.py +89 -0
- extern/video_depth_anything/dinov2_layers/swiglu_ffn.py +63 -0
- extern/video_depth_anything/dpt.py +160 -0
- extern/video_depth_anything/dpt_temporal.py +96 -0
- extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc +0 -0
- extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc +0 -0
- extern/video_depth_anything/motion_module/attention.py +429 -0
- extern/video_depth_anything/motion_module/motion_module.py +297 -0
- extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc +0 -0
- extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc +0 -0
- extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc +0 -0
- extern/video_depth_anything/util/blocks.py +162 -0
- extern/video_depth_anything/util/transform.py +158 -0
- extern/video_depth_anything/util/util.py +74 -0
- extern/video_depth_anything/vdademo.py +63 -0
- extern/video_depth_anything/video_depth.py +154 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
app.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import sys
|
4 |
+
from demo import TrajCrafter
|
5 |
+
import random
|
6 |
+
import gradio as gr
|
7 |
+
import random
|
8 |
+
from inference import get_parser
|
9 |
+
from datetime import datetime
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
# 解析命令行参数
|
13 |
+
|
14 |
+
traj_examples = [
|
15 |
+
['20; -30; 0.3; 0; 0'],
|
16 |
+
['0; 0; -0.3; -2; 2'],
|
17 |
+
]
|
18 |
+
|
19 |
+
# inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
|
20 |
+
|
21 |
+
img_examples = [
|
22 |
+
['test/videos/0-NNvgaTcVzAG0-r.mp4',2,1,'0; -30; 0.5; -2; 0',50,43],
|
23 |
+
['test/videos/tUfDESZsQFhdDW9S.mp4',2,1,'0; 30; -0.4; 2; 0',50,43],
|
24 |
+
['test/videos/part-2-3.mp4',2,1,'20; 40; 0.5; 2; 0',50,43],
|
25 |
+
['test/videos/p7.mp4',2,1,'0; -50; 0.3; 0; 0',50,43],
|
26 |
+
['test/videos/UST-fn-RvhJwMR5S.mp4',2,1,'0; -35; 0.4; 0; 0',50,43],
|
27 |
+
]
|
28 |
+
|
29 |
+
max_seed = 2 ** 31
|
30 |
+
|
31 |
+
parser = get_parser() # infer_config.py
|
32 |
+
opts = parser.parse_args() # default device: 'cuda:0'
|
33 |
+
opts.weight_dtype = torch.bfloat16
|
34 |
+
tmp = datetime.now().strftime("%Y%m%d_%H%M")
|
35 |
+
opts.save_dir = f'./experiments/gradio_{tmp}'
|
36 |
+
os.makedirs(opts.save_dir,exist_ok=True)
|
37 |
+
test_tensor = torch.Tensor([0]).cuda()
|
38 |
+
opts.device = str(test_tensor.device)
|
39 |
+
|
40 |
+
CAMERA_MOTION_MODE = ["Basic Camera Trajectory", "Custom Camera Trajectory"]
|
41 |
+
|
42 |
+
def show_traj(mode):
|
43 |
+
if mode == 'Orbit Left':
|
44 |
+
return gr.update(value='0; -30; 0; 0; 0',visible=True),gr.update(visible=False)
|
45 |
+
elif mode == 'Orbit Right':
|
46 |
+
return gr.update(value='0; 30; 0; 0; 0',visible=True),gr.update(visible=False)
|
47 |
+
elif mode == 'Orbit Up':
|
48 |
+
return gr.update(value='30; 0; 0; 0; 0',visible=True),gr.update(visible=False)
|
49 |
+
elif mode == 'Orbit Down':
|
50 |
+
return gr.update(value='-20; 0; 0; 0; 0',visible=True), gr.update(visible=False)
|
51 |
+
if mode == 'Pan Left':
|
52 |
+
return gr.update(value='0; 0; 0; -2; 0',visible=True),gr.update(visible=False)
|
53 |
+
elif mode == 'Pan Right':
|
54 |
+
return gr.update(value='0; 0; 0; 2; 0',visible=True),gr.update(visible=False)
|
55 |
+
elif mode == 'Pan Up':
|
56 |
+
return gr.update(value='0; 0; 0; 0; 2',visible=True),gr.update(visible=False)
|
57 |
+
elif mode == 'Pan Down':
|
58 |
+
return gr.update(value='0; 0; 0; 0; -2',visible=True), gr.update(visible=False)
|
59 |
+
elif mode == 'Zoom in':
|
60 |
+
return gr.update(value='0; 0; 0.5; 0; 0',visible=True), gr.update(visible=False)
|
61 |
+
elif mode == 'Zoom out':
|
62 |
+
return gr.update(value='0; 0; -0.5; 0; 0',visible=True), gr.update(visible=False)
|
63 |
+
elif mode == 'Customize':
|
64 |
+
return gr.update(value='0; 0; 0; 0; 0',visible=True), gr.update(visible=True)
|
65 |
+
elif mode == 'Reset':
|
66 |
+
return gr.update(value='0; 0; 0; 0; 0',visible=False), gr.update(visible=False)
|
67 |
+
|
68 |
+
def trajcrafter_demo(opts):
|
69 |
+
# css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
|
70 |
+
css = """
|
71 |
+
#input_img {max-width: 1024px !important}
|
72 |
+
#output_vid {max-width: 1024px; max-height:576px}
|
73 |
+
#random_button {max-width: 100px !important}
|
74 |
+
.generate-btn {
|
75 |
+
background: linear-gradient(45deg, #2196F3, #1976D2) !important;
|
76 |
+
border: none !important;
|
77 |
+
color: white !important;
|
78 |
+
font-weight: bold !important;
|
79 |
+
box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
|
80 |
+
}
|
81 |
+
.generate-btn:hover {
|
82 |
+
background: linear-gradient(45deg, #1976D2, #1565C0) !important;
|
83 |
+
box-shadow: 0 4px 8px rgba(0,0,0,0.3) !important;
|
84 |
+
}
|
85 |
+
"""
|
86 |
+
image2video = TrajCrafter(opts,gradio=True)
|
87 |
+
# image2video.run_both = spaces.GPU(image2video.run_both, duration=290) # fixme
|
88 |
+
with gr.Blocks(analytics_enabled=False, css=css) as trajcrafter_iface:
|
89 |
+
gr.Markdown("<div align='center'> <h1> TrajectoryCrafter: Redirecting View Trajectory for Monocular Videos via Diffusion Models </span> </h1>")
|
90 |
+
# # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
|
91 |
+
# # <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02048'> [ArXiv] </a>\
|
92 |
+
# # <a style='font-size:18px;color: #000000' href='https://drexubery.github.io/ViewCrafter/'> [Project Page] </a>\
|
93 |
+
# # <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Drexubery/ViewCrafter'> [Github] </a>\
|
94 |
+
# # <a style='font-size:18px;color: #000000' href='https://www.youtube.com/watch?v=WGIEmu9eXmU'> [Video] </a> </div>")
|
95 |
+
|
96 |
+
|
97 |
+
with gr.Row(equal_height=True):
|
98 |
+
with gr.Column():
|
99 |
+
# # step 1: input an image
|
100 |
+
# gr.Markdown("---\n## Step 1: Input an Image, selet an elevation angle and a center_scale factor", show_label=False, visible=True)
|
101 |
+
# gr.Markdown("<div align='left' style='font-size:18px;color: #000000'>1. Estimate an elevation angle that represents the angle at which the image was taken; a value bigger than 0 indicates a top-down view, and it doesn't need to be precise. <br>2. The origin of the world coordinate system is by default defined at the point cloud corresponding to the center pixel of the input image. You can adjust the position of the origin by modifying center_scale; a value smaller than 1 brings the origin closer to you.</div>")
|
102 |
+
i2v_input_video = gr.Video(label="Input Video", elem_id="input_video", format="mp4")
|
103 |
+
|
104 |
+
|
105 |
+
with gr.Column():
|
106 |
+
i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
|
107 |
+
show_share_button=True)
|
108 |
+
|
109 |
+
with gr.Row():
|
110 |
+
with gr.Row():
|
111 |
+
i2v_stride = gr.Slider(minimum=1, maximum=3, step=1, elem_id="stride", label="Stride", value=1)
|
112 |
+
i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale",
|
113 |
+
label="center_scale", value=1)
|
114 |
+
i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps",
|
115 |
+
value=50)
|
116 |
+
i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
|
117 |
+
with gr.Row():
|
118 |
+
pan_left = gr.Button(value="Pan Left")
|
119 |
+
pan_right = gr.Button(value="Pan Right")
|
120 |
+
pan_up = gr.Button(value="Pan Up")
|
121 |
+
pan_down = gr.Button(value="Pan Down")
|
122 |
+
with gr.Row():
|
123 |
+
orbit_left = gr.Button(value="Orbit Left")
|
124 |
+
orbit_right = gr.Button(value="Orbit Right")
|
125 |
+
orbit_up = gr.Button(value="Orbit Up")
|
126 |
+
orbit_down = gr.Button(value="Orbit Down")
|
127 |
+
with gr.Row():
|
128 |
+
zin = gr.Button(value="Zoom in")
|
129 |
+
zout = gr.Button(value="Zoom out")
|
130 |
+
custom = gr.Button(value="Customize")
|
131 |
+
reset = gr.Button(value="Reset")
|
132 |
+
with gr.Column():
|
133 |
+
with gr.Row():
|
134 |
+
with gr.Column():
|
135 |
+
i2v_pose = gr.Text(value='0; 0; 0; 0; 0', label="Traget camera pose (theta, phi, r, x, y)",
|
136 |
+
visible=False)
|
137 |
+
with gr.Column(visible=False) as i2v_egs:
|
138 |
+
gr.Markdown(
|
139 |
+
"<div align='left' style='font-size:18px;color: #000000'>Please refer to <a href='https://github.com/TrajectoryCrafter/TrajectoryCrafter/blob/main/docs/config_help.md' target='_blank'>tutorial</a> for customizing camera trajectory.</div>")
|
140 |
+
gr.Examples(examples=traj_examples,
|
141 |
+
inputs=[i2v_pose],
|
142 |
+
)
|
143 |
+
with gr.Column():
|
144 |
+
i2v_end_btn = gr.Button("Generate video", scale=2, size="lg", variant="primary", elem_classes="generate-btn")
|
145 |
+
|
146 |
+
|
147 |
+
# with gr.Column():
|
148 |
+
# i2v_input_video = gr.Video(label="Input Video", elem_id="input_video", format="mp4")
|
149 |
+
# i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
|
150 |
+
# with gr.Row():
|
151 |
+
# # i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
|
152 |
+
# i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
|
153 |
+
# i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
|
154 |
+
# i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
|
155 |
+
# with gr.Column():
|
156 |
+
# with gr.Row():
|
157 |
+
# left = gr.Button(value = "Left")
|
158 |
+
# right = gr.Button(value = "Right")
|
159 |
+
# up = gr.Button(value = "Up")
|
160 |
+
# with gr.Row():
|
161 |
+
# down = gr.Button(value = "Down")
|
162 |
+
# zin = gr.Button(value = "Zoom in")
|
163 |
+
# zout = gr.Button(value = "Zoom out")
|
164 |
+
# with gr.Row():
|
165 |
+
# custom = gr.Button(value = "Customize")
|
166 |
+
# reset = gr.Button(value = "Reset")
|
167 |
+
|
168 |
+
|
169 |
+
# step 3 - Generate video
|
170 |
+
# with gr.Column():
|
171 |
+
# gr.Markdown("---\n## Step 3: Generate video", show_label=False, visible=True)
|
172 |
+
# gr.Markdown("<div align='left' style='font-size:18px;color: #000000'> You can reduce the sampling steps for faster inference; try different random seed if the result is not satisfying. </div>")
|
173 |
+
# i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
|
174 |
+
# i2v_end_btn = gr.Button("Generate video")
|
175 |
+
# i2v_traj_video = gr.Video(label="Camera Trajectory",elem_id="traj_vid",autoplay=True,show_share_button=True)
|
176 |
+
|
177 |
+
# with gr.Column(scale=1.5):
|
178 |
+
# with gr.Row():
|
179 |
+
# # i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
|
180 |
+
# i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
|
181 |
+
# i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
|
182 |
+
# i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
|
183 |
+
# with gr.Row():
|
184 |
+
# pan_left = gr.Button(value = "Pan Left")
|
185 |
+
# pan_right = gr.Button(value = "Pan Right")
|
186 |
+
# pan_up = gr.Button(value = "Pan Up")
|
187 |
+
# pan_down = gr.Button(value = "Pan Down")
|
188 |
+
# with gr.Row():
|
189 |
+
# orbit_left = gr.Button(value = "Orbit Left")
|
190 |
+
# orbit_right = gr.Button(value = "Orbit Right")
|
191 |
+
# orbit_up = gr.Button(value = "Orbit Up")
|
192 |
+
# orbit_down = gr.Button(value = "Orbit Down")
|
193 |
+
# with gr.Row():
|
194 |
+
# zin = gr.Button(value = "Zoom in")
|
195 |
+
# zout = gr.Button(value = "Zoom out")
|
196 |
+
# custom = gr.Button(value = "Customize")
|
197 |
+
# reset = gr.Button(value = "Reset")
|
198 |
+
# with gr.Column():
|
199 |
+
# with gr.Row():
|
200 |
+
# with gr.Column():
|
201 |
+
# i2v_pose = gr.Text(value = '0; 0; 0; 0; 0', label="Traget camera pose (theta, phi, r, x, y)",visible=False)
|
202 |
+
# with gr.Column(visible=False) as i2v_egs:
|
203 |
+
# gr.Markdown("<div align='left' style='font-size:18px;color: #000000'>Please refer to the <a href='https://github.com/Drexubery/ViewCrafter/blob/main/docs/gradio_tutorial.md' target='_blank'>tutorial</a> for customizing camera trajectory.</div>")
|
204 |
+
# gr.Examples(examples=traj_examples,
|
205 |
+
# inputs=[i2v_pose],
|
206 |
+
# )
|
207 |
+
# with gr.Row():
|
208 |
+
# i2v_end_btn = gr.Button("Generate video")
|
209 |
+
# step 3 - Generate video
|
210 |
+
# with gr.Row():
|
211 |
+
# with gr.Column():
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
i2v_end_btn.click(inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
|
216 |
+
outputs=[i2v_output_video],
|
217 |
+
fn = image2video.run_gradio
|
218 |
+
)
|
219 |
+
|
220 |
+
pan_left.click(inputs=[pan_left],
|
221 |
+
outputs=[i2v_pose,i2v_egs],
|
222 |
+
fn = show_traj
|
223 |
+
)
|
224 |
+
pan_right.click(inputs=[pan_right],
|
225 |
+
outputs=[i2v_pose,i2v_egs],
|
226 |
+
fn = show_traj
|
227 |
+
)
|
228 |
+
pan_up.click(inputs=[pan_up],
|
229 |
+
outputs=[i2v_pose,i2v_egs],
|
230 |
+
fn = show_traj
|
231 |
+
)
|
232 |
+
pan_down.click(inputs=[pan_down],
|
233 |
+
outputs=[i2v_pose,i2v_egs],
|
234 |
+
fn = show_traj
|
235 |
+
)
|
236 |
+
orbit_left.click(inputs=[orbit_left],
|
237 |
+
outputs=[i2v_pose,i2v_egs],
|
238 |
+
fn = show_traj
|
239 |
+
)
|
240 |
+
orbit_right.click(inputs=[orbit_right],
|
241 |
+
outputs=[i2v_pose,i2v_egs],
|
242 |
+
fn = show_traj
|
243 |
+
)
|
244 |
+
orbit_up.click(inputs=[orbit_up],
|
245 |
+
outputs=[i2v_pose,i2v_egs],
|
246 |
+
fn = show_traj
|
247 |
+
)
|
248 |
+
orbit_down.click(inputs=[orbit_down],
|
249 |
+
outputs=[i2v_pose,i2v_egs],
|
250 |
+
fn = show_traj
|
251 |
+
)
|
252 |
+
zin.click(inputs=[zin],
|
253 |
+
outputs=[i2v_pose,i2v_egs],
|
254 |
+
fn = show_traj
|
255 |
+
)
|
256 |
+
zout.click(inputs=[zout],
|
257 |
+
outputs=[i2v_pose,i2v_egs],
|
258 |
+
fn = show_traj
|
259 |
+
)
|
260 |
+
custom.click(inputs=[custom],
|
261 |
+
outputs=[i2v_pose,i2v_egs],
|
262 |
+
fn = show_traj
|
263 |
+
)
|
264 |
+
reset.click(inputs=[reset],
|
265 |
+
outputs=[i2v_pose,i2v_egs],
|
266 |
+
fn = show_traj
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
gr.Examples(examples=img_examples,
|
271 |
+
# inputs=[i2v_input_video,i2v_stride],
|
272 |
+
inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
|
273 |
+
)
|
274 |
+
|
275 |
+
return trajcrafter_iface
|
276 |
+
|
277 |
+
|
278 |
+
trajcrafter_iface = trajcrafter_demo(opts)
|
279 |
+
trajcrafter_iface.queue(max_size=10)
|
280 |
+
# trajcrafter_iface.launch(server_name=args.server_name, max_threads=10, debug=True)
|
281 |
+
trajcrafter_iface.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False, max_threads=10)
|
282 |
+
|
283 |
+
|
284 |
+
|
demo.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from extern.depthcrafter.infer import DepthCrafterDemo
|
5 |
+
# from extern.video_depth_anything.vdademo import VDADemo
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from transformers import T5EncoderModel
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
from PIL import Image
|
11 |
+
from models.crosstransformer3d import CrossTransformer3DModel
|
12 |
+
from models.autoencoder_magvit import AutoencoderKLCogVideoX
|
13 |
+
from models.pipeline_trajectorycrafter import TrajCrafter_Pipeline
|
14 |
+
from models.utils import *
|
15 |
+
from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
|
16 |
+
DPMSolverMultistepScheduler,
|
17 |
+
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
|
18 |
+
PNDMScheduler)
|
19 |
+
from transformers import AutoProcessor, Blip2ForConditionalGeneration
|
20 |
+
|
21 |
+
class TrajCrafter:
|
22 |
+
def __init__(self, opts, gradio=False):
|
23 |
+
self.funwarp = Warper(device=opts.device)
|
24 |
+
# self.depth_estimater = VDADemo(pre_train_path=opts.pre_train_path_vda,device=opts.device)
|
25 |
+
self.depth_estimater = DepthCrafterDemo(unet_path=opts.unet_path,pre_train_path=opts.pre_train_path,cpu_offload=opts.cpu_offload,device=opts.device)
|
26 |
+
self.caption_processor = AutoProcessor.from_pretrained(opts.blip_path)
|
27 |
+
self.captioner = Blip2ForConditionalGeneration.from_pretrained(opts.blip_path, torch_dtype=torch.float16).to(opts.device)
|
28 |
+
self.setup_diffusion(opts)
|
29 |
+
if gradio:
|
30 |
+
self.opts=opts
|
31 |
+
|
32 |
+
def infer_gradual(self,opts):
|
33 |
+
frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
|
34 |
+
prompt = self.get_caption(opts,frames[opts.video_length//2])
|
35 |
+
# depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
|
36 |
+
depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
|
37 |
+
frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
|
38 |
+
assert frames.shape[0] == opts.video_length
|
39 |
+
pose_s, pose_t, K = self.get_poses(opts,depths,num_frames = opts.video_length)
|
40 |
+
warped_images = []
|
41 |
+
masks = []
|
42 |
+
for i in tqdm(range(opts.video_length)):
|
43 |
+
warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i:i+1], None, depths[i:i+1], pose_s[i:i+1], pose_t[i:i+1], K[i:i+1], None, opts.mask,twice=False)
|
44 |
+
warped_images.append(warped_frame2)
|
45 |
+
masks.append(mask2)
|
46 |
+
cond_video = (torch.cat(warped_images)+1.)/2.
|
47 |
+
cond_masks = torch.cat(masks)
|
48 |
+
|
49 |
+
frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
|
50 |
+
cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
|
51 |
+
cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
|
52 |
+
save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
|
53 |
+
save_video(cond_video.permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
|
54 |
+
save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
|
55 |
+
|
56 |
+
frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
|
57 |
+
frames_ref = frames[:,:,:10,:,:]
|
58 |
+
cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
|
59 |
+
cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
|
60 |
+
generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
|
61 |
+
|
62 |
+
del self.depth_estimater
|
63 |
+
del self.caption_processor
|
64 |
+
del self.captioner
|
65 |
+
gc.collect()
|
66 |
+
torch.cuda.empty_cache()
|
67 |
+
with torch.no_grad():
|
68 |
+
sample = self.pipeline(
|
69 |
+
prompt,
|
70 |
+
num_frames = opts.video_length,
|
71 |
+
negative_prompt = opts.negative_prompt,
|
72 |
+
height = opts.sample_size[0],
|
73 |
+
width = opts.sample_size[1],
|
74 |
+
generator = generator,
|
75 |
+
guidance_scale = opts.diffusion_guidance_scale,
|
76 |
+
num_inference_steps = opts.diffusion_inference_steps,
|
77 |
+
video = cond_video,
|
78 |
+
mask_video = cond_masks,
|
79 |
+
reference = frames_ref,
|
80 |
+
).videos
|
81 |
+
save_video(sample[0].permute(1,2,3,0), os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
|
82 |
+
|
83 |
+
viz = True
|
84 |
+
if viz:
|
85 |
+
tensor_left = frames[0].to(opts.device)
|
86 |
+
tensor_right = sample[0].to(opts.device)
|
87 |
+
interval = torch.ones(3, 49, 384, 30).to(opts.device)
|
88 |
+
result = torch.cat((tensor_left, interval, tensor_right), dim=3)
|
89 |
+
result_reverse = torch.flip(result, dims=[1])
|
90 |
+
final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
|
91 |
+
save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*2)
|
92 |
+
|
93 |
+
def infer_direct(self,opts):
|
94 |
+
opts.cut = 20
|
95 |
+
frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
|
96 |
+
prompt = self.get_caption(opts,frames[opts.video_length//2])
|
97 |
+
# depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
|
98 |
+
depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
|
99 |
+
frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
|
100 |
+
assert frames.shape[0] == opts.video_length
|
101 |
+
pose_s, pose_t, K = self.get_poses(opts,depths,num_frames = opts.cut)
|
102 |
+
|
103 |
+
warped_images = []
|
104 |
+
masks = []
|
105 |
+
for i in tqdm(range(opts.video_length)):
|
106 |
+
if i < opts.cut:
|
107 |
+
warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[0:1], None, depths[0:1], pose_s[0:1], pose_t[i:i+1], K[0:1], None, opts.mask,twice=False)
|
108 |
+
warped_images.append(warped_frame2)
|
109 |
+
masks.append(mask2)
|
110 |
+
else:
|
111 |
+
warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i-opts.cut:i-opts.cut+1], None, depths[i-opts.cut:i-opts.cut+1], pose_s[0:1], pose_t[-1:], K[0:1], None, opts.mask,twice=False)
|
112 |
+
warped_images.append(warped_frame2)
|
113 |
+
masks.append(mask2)
|
114 |
+
cond_video = (torch.cat(warped_images)+1.)/2.
|
115 |
+
cond_masks = torch.cat(masks)
|
116 |
+
frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
|
117 |
+
cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
|
118 |
+
cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
|
119 |
+
save_video((frames[:opts.video_length-opts.cut].permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
|
120 |
+
save_video(cond_video[opts.cut:].permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
|
121 |
+
save_video(cond_masks[opts.cut:].repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
|
122 |
+
frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
|
123 |
+
frames_ref = frames[:,:,:10,:,:]
|
124 |
+
cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
|
125 |
+
cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
|
126 |
+
generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
|
127 |
+
|
128 |
+
del self.depth_estimater
|
129 |
+
del self.caption_processor
|
130 |
+
del self.captioner
|
131 |
+
gc.collect()
|
132 |
+
torch.cuda.empty_cache()
|
133 |
+
with torch.no_grad():
|
134 |
+
sample = self.pipeline(
|
135 |
+
prompt,
|
136 |
+
num_frames = opts.video_length,
|
137 |
+
negative_prompt = opts.negative_prompt,
|
138 |
+
height = opts.sample_size[0],
|
139 |
+
width = opts.sample_size[1],
|
140 |
+
generator = generator,
|
141 |
+
guidance_scale = opts.diffusion_guidance_scale,
|
142 |
+
num_inference_steps = opts.diffusion_inference_steps,
|
143 |
+
video = cond_video,
|
144 |
+
mask_video = cond_masks,
|
145 |
+
reference = frames_ref,
|
146 |
+
).videos
|
147 |
+
save_video(sample[0].permute(1,2,3,0)[opts.cut:], os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
|
148 |
+
|
149 |
+
viz = True
|
150 |
+
if viz:
|
151 |
+
tensor_left = frames[0][:,:opts.video_length-opts.cut,...].to(opts.device)
|
152 |
+
tensor_right = sample[0][:,opts.cut:,...].to(opts.device)
|
153 |
+
interval = torch.ones(3, opts.video_length-opts.cut, 384, 30).to(opts.device)
|
154 |
+
result = torch.cat((tensor_left, interval, tensor_right), dim=3)
|
155 |
+
result_reverse = torch.flip(result, dims=[1])
|
156 |
+
final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
|
157 |
+
save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*2)
|
158 |
+
|
159 |
+
def infer_bullet(self,opts):
|
160 |
+
frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
|
161 |
+
prompt = self.get_caption(opts,frames[opts.video_length//2])
|
162 |
+
# depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
|
163 |
+
depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
|
164 |
+
|
165 |
+
frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
|
166 |
+
assert frames.shape[0] == opts.video_length
|
167 |
+
pose_s, pose_t, K = self.get_poses(opts,depths, num_frames = opts.video_length)
|
168 |
+
|
169 |
+
warped_images = []
|
170 |
+
masks = []
|
171 |
+
for i in tqdm(range(opts.video_length)):
|
172 |
+
warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[-1:], None, depths[-1:], pose_s[0:1], pose_t[i:i+1], K[0:1], None, opts.mask,twice=False)
|
173 |
+
warped_images.append(warped_frame2)
|
174 |
+
masks.append(mask2)
|
175 |
+
cond_video = (torch.cat(warped_images)+1.)/2.
|
176 |
+
cond_masks = torch.cat(masks)
|
177 |
+
frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
|
178 |
+
cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
|
179 |
+
cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
|
180 |
+
save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
|
181 |
+
save_video(cond_video.permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
|
182 |
+
save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
|
183 |
+
frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
|
184 |
+
frames_ref = frames[:,:,-10:,:,:]
|
185 |
+
cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
|
186 |
+
cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
|
187 |
+
generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
|
188 |
+
|
189 |
+
del self.depth_estimater
|
190 |
+
del self.caption_processor
|
191 |
+
del self.captioner
|
192 |
+
gc.collect()
|
193 |
+
torch.cuda.empty_cache()
|
194 |
+
with torch.no_grad():
|
195 |
+
sample = self.pipeline(
|
196 |
+
prompt,
|
197 |
+
num_frames = opts.video_length,
|
198 |
+
negative_prompt = opts.negative_prompt,
|
199 |
+
height = opts.sample_size[0],
|
200 |
+
width = opts.sample_size[1],
|
201 |
+
generator = generator,
|
202 |
+
guidance_scale = opts.diffusion_guidance_scale,
|
203 |
+
num_inference_steps = opts.diffusion_inference_steps,
|
204 |
+
video = cond_video,
|
205 |
+
mask_video = cond_masks,
|
206 |
+
reference = frames_ref,
|
207 |
+
).videos
|
208 |
+
save_video(sample[0].permute(1,2,3,0), os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
|
209 |
+
|
210 |
+
viz = True
|
211 |
+
if viz:
|
212 |
+
tensor_left = frames[0].to(opts.device)
|
213 |
+
tensor_left_full = torch.cat([tensor_left,tensor_left[:,-1:,:,:].repeat(1,48,1,1)],dim=1)
|
214 |
+
tensor_right = sample[0].to(opts.device)
|
215 |
+
tensor_right_full = torch.cat([tensor_left,tensor_right[:,1:,:,:]],dim=1)
|
216 |
+
interval = torch.ones(3, 49*2-1, 384, 30).to(opts.device)
|
217 |
+
result = torch.cat((tensor_left_full, interval, tensor_right_full), dim=3)
|
218 |
+
result_reverse = torch.flip(result, dims=[1])
|
219 |
+
final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
|
220 |
+
save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*4)
|
221 |
+
|
222 |
+
def get_caption(self,opts,image):
|
223 |
+
image_array = (image * 255).astype(np.uint8)
|
224 |
+
pil_image = Image.fromarray(image_array)
|
225 |
+
inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(opts.device, torch.float16)
|
226 |
+
generated_ids = self.captioner.generate(**inputs)
|
227 |
+
generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
228 |
+
return generated_text + opts.refine_prompt
|
229 |
+
|
230 |
+
def get_poses(self,opts,depths,num_frames):
|
231 |
+
radius = depths[0,0,depths.shape[-2]//2,depths.shape[-1]//2].cpu()*opts.radius_scale
|
232 |
+
radius = min(radius, 5)
|
233 |
+
cx = 512. #depths.shape[-1]//2
|
234 |
+
cy = 288. #depths.shape[-2]//2
|
235 |
+
f = 500 #500.
|
236 |
+
K = torch.tensor([[f, 0., cx],[ 0., f, cy],[ 0., 0., 1.]]).repeat(num_frames,1,1).to(opts.device)
|
237 |
+
c2w_init = torch.tensor([[-1., 0., 0., 0.],
|
238 |
+
[ 0., 1., 0., 0.],
|
239 |
+
[ 0., 0., -1., 0.],
|
240 |
+
[ 0., 0., 0., 1.]]).to(opts.device).unsqueeze(0)
|
241 |
+
if opts.camera == 'target':
|
242 |
+
dtheta, dphi, dr, dx, dy = opts.target_pose
|
243 |
+
poses = generate_traj_specified(c2w_init, dtheta, dphi, dr*radius, dx, dy, num_frames, opts.device)
|
244 |
+
elif opts.camera =='traj':
|
245 |
+
with open(opts.traj_txt, 'r') as file:
|
246 |
+
lines = file.readlines()
|
247 |
+
theta = [float(i) for i in lines[0].split()]
|
248 |
+
phi = [float(i) for i in lines[1].split()]
|
249 |
+
r = [float(i)*radius for i in lines[2].split()]
|
250 |
+
poses = generate_traj_txt(c2w_init, phi, theta, r, num_frames, opts.device)
|
251 |
+
poses[:,2, 3] = poses[:,2, 3] + radius
|
252 |
+
pose_s = poses[opts.anchor_idx:opts.anchor_idx+1].repeat(num_frames,1,1)
|
253 |
+
pose_t = poses
|
254 |
+
return pose_s, pose_t, K
|
255 |
+
|
256 |
+
def setup_diffusion(self,opts):
|
257 |
+
# transformer = CrossTransformer3DModel.from_pretrained_cus(opts.transformer_path).to(opts.weight_dtype)
|
258 |
+
transformer = CrossTransformer3DModel.from_pretrained(opts.transformer_path).to(opts.weight_dtype)
|
259 |
+
# transformer = transformer.to(opts.weight_dtype)
|
260 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(
|
261 |
+
opts.model_name,
|
262 |
+
subfolder="vae"
|
263 |
+
).to(opts.weight_dtype)
|
264 |
+
text_encoder = T5EncoderModel.from_pretrained(
|
265 |
+
opts.model_name, subfolder="text_encoder", torch_dtype=opts.weight_dtype
|
266 |
+
)
|
267 |
+
# Get Scheduler
|
268 |
+
Choosen_Scheduler = {
|
269 |
+
"Euler": EulerDiscreteScheduler,
|
270 |
+
"Euler A": EulerAncestralDiscreteScheduler,
|
271 |
+
"DPM++": DPMSolverMultistepScheduler,
|
272 |
+
"PNDM": PNDMScheduler,
|
273 |
+
"DDIM_Cog": CogVideoXDDIMScheduler,
|
274 |
+
"DDIM_Origin": DDIMScheduler,
|
275 |
+
}[opts.sampler_name]
|
276 |
+
scheduler = Choosen_Scheduler.from_pretrained(
|
277 |
+
opts.model_name,
|
278 |
+
subfolder="scheduler"
|
279 |
+
)
|
280 |
+
|
281 |
+
self.pipeline = TrajCrafter_Pipeline.from_pretrained(
|
282 |
+
opts.model_name,
|
283 |
+
vae=vae,
|
284 |
+
text_encoder=text_encoder,
|
285 |
+
transformer=transformer,
|
286 |
+
scheduler=scheduler,
|
287 |
+
torch_dtype=opts.weight_dtype
|
288 |
+
)
|
289 |
+
|
290 |
+
if opts.low_gpu_memory_mode:
|
291 |
+
self.pipeline.enable_sequential_cpu_offload()
|
292 |
+
else:
|
293 |
+
self.pipeline.enable_model_cpu_offload()
|
294 |
+
|
295 |
+
def run_gradio(self,input_video, stride, radius_scale, pose, steps, seed):
|
296 |
+
frames = read_video_frames(input_video, self.opts.video_length, stride,self.opts.max_res)
|
297 |
+
prompt = self.get_caption(self.opts,frames[self.opts.video_length//2])
|
298 |
+
# depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
|
299 |
+
depths= self.depth_estimater.infer(frames, self.opts.near, self.opts.far, self.opts.depth_inference_steps, self.opts.depth_guidance_scale, window_size=self.opts.window_size, overlap=self.opts.overlap).to(self.opts.device)
|
300 |
+
frames = torch.from_numpy(frames).permute(0,3,1,2).to(self.opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
|
301 |
+
num_frames = frames.shape[0]
|
302 |
+
assert num_frames == self.opts.video_length
|
303 |
+
radius_scale = float(radius_scale)
|
304 |
+
radius = depths[0,0,depths.shape[-2]//2,depths.shape[-1]//2].cpu()*radius_scale
|
305 |
+
radius = min(radius, 5)
|
306 |
+
cx = 512. #depths.shape[-1]//2
|
307 |
+
cy = 288. #depths.shape[-2]//2
|
308 |
+
f = 500 #500.
|
309 |
+
K = torch.tensor([[f, 0., cx],[ 0., f, cy],[ 0., 0., 1.]]).repeat(num_frames,1,1).to(self.opts.device)
|
310 |
+
c2w_init = torch.tensor([[-1., 0., 0., 0.],
|
311 |
+
[ 0., 1., 0., 0.],
|
312 |
+
[ 0., 0., -1., 0.],
|
313 |
+
[ 0., 0., 0., 1.]]).to(self.opts.device).unsqueeze(0)
|
314 |
+
|
315 |
+
# import pdb
|
316 |
+
# pdb.set_trace()
|
317 |
+
theta,phi,r,x,y = [float(i) for i in pose.split(';')]
|
318 |
+
# theta,phi,r,x,y = [float(i) for i in theta.split()],[float(i) for i in phi.split()],[float(i) for i in r.split()],[float(i) for i in x.split()],[float(i) for i in y.split()]
|
319 |
+
# target mode
|
320 |
+
poses = generate_traj_specified(c2w_init, theta, phi, r*radius, x, y, num_frames, self.opts.device)
|
321 |
+
poses[:,2, 3] = poses[:,2, 3] + radius
|
322 |
+
pose_s = poses[self.opts.anchor_idx:self.opts.anchor_idx+1].repeat(num_frames,1,1)
|
323 |
+
pose_t = poses
|
324 |
+
|
325 |
+
warped_images = []
|
326 |
+
masks = []
|
327 |
+
for i in tqdm(range(self.opts.video_length)):
|
328 |
+
warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i:i+1], None, depths[i:i+1], pose_s[i:i+1], pose_t[i:i+1], K[i:i+1], None, self.opts.mask,twice=False)
|
329 |
+
warped_images.append(warped_frame2)
|
330 |
+
masks.append(mask2)
|
331 |
+
cond_video = (torch.cat(warped_images)+1.)/2.
|
332 |
+
cond_masks = torch.cat(masks)
|
333 |
+
|
334 |
+
frames = F.interpolate(frames, size=self.opts.sample_size, mode='bilinear', align_corners=False)
|
335 |
+
cond_video = F.interpolate(cond_video, size=self.opts.sample_size, mode='bilinear', align_corners=False)
|
336 |
+
cond_masks = F.interpolate(cond_masks, size=self.opts.sample_size, mode='nearest')
|
337 |
+
save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(self.opts.save_dir,'input.mp4'),fps=self.opts.fps)
|
338 |
+
save_video(cond_video.permute(0,2,3,1), os.path.join(self.opts.save_dir,'render.mp4'),fps=self.opts.fps)
|
339 |
+
save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(self.opts.save_dir,'mask.mp4'),fps=self.opts.fps)
|
340 |
+
|
341 |
+
frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
|
342 |
+
frames_ref = frames[:,:,:10,:,:]
|
343 |
+
cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
|
344 |
+
cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
|
345 |
+
generator = torch.Generator(device=self.opts.device).manual_seed(seed)
|
346 |
+
|
347 |
+
# del self.depth_estimater
|
348 |
+
# del self.caption_processor
|
349 |
+
# del self.captioner
|
350 |
+
# gc.collect()
|
351 |
+
torch.cuda.empty_cache()
|
352 |
+
with torch.no_grad():
|
353 |
+
sample = self.pipeline(
|
354 |
+
prompt,
|
355 |
+
num_frames = self.opts.video_length,
|
356 |
+
negative_prompt = self.opts.negative_prompt,
|
357 |
+
height = self.opts.sample_size[0],
|
358 |
+
width = self.opts.sample_size[1],
|
359 |
+
generator = generator,
|
360 |
+
guidance_scale = self.opts.diffusion_guidance_scale,
|
361 |
+
num_inference_steps = steps,
|
362 |
+
video = cond_video,
|
363 |
+
mask_video = cond_masks,
|
364 |
+
reference = frames_ref,
|
365 |
+
).videos
|
366 |
+
save_video(sample[0].permute(1,2,3,0), os.path.join(self.opts.save_dir,'gen.mp4'), fps=self.opts.fps)
|
367 |
+
|
368 |
+
viz = True
|
369 |
+
if viz:
|
370 |
+
tensor_left = frames[0].to(self.opts.device)
|
371 |
+
tensor_right = sample[0].to(self.opts.device)
|
372 |
+
interval = torch.ones(3, 49, 384, 30).to(self.opts.device)
|
373 |
+
result = torch.cat((tensor_left, interval, tensor_right), dim=3)
|
374 |
+
result_reverse = torch.flip(result, dims=[1])
|
375 |
+
final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
|
376 |
+
save_video(final_result.permute(1,2,3,0), os.path.join(self.opts.save_dir,'viz.mp4'), fps=self.opts.fps*2)
|
377 |
+
return os.path.join(self.opts.save_dir,'viz.mp4')
|
docs/config_help.md
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Important configuration for [inference.py](../inference.py):
|
2 |
+
|
3 |
+
### 1. General configs
|
4 |
+
| Configuration | Default Value | Explanation |
|
5 |
+
|:----------------- |:--------------- |:-------------------------------------------------------- |
|
6 |
+
| `--video_path` | `None` | Input video file path |
|
7 |
+
| `--out_dir` | `./experiments/`| Output directory |
|
8 |
+
| `--device` | `cuda:0` | The device to use (e.g., CPU or GPU) |
|
9 |
+
| `--exp_name` | `None` | Experiment name, defaults to video file name |
|
10 |
+
| `--seed` | `43` | Random seed for reproducibility |
|
11 |
+
| `--video_length` | `49` | Length of the video frames (number of frames) |
|
12 |
+
| `--fps` | `10` | fps for saved video |
|
13 |
+
| `--stride` | `1` | Sampling stride for input video (frame interval) |
|
14 |
+
| `--server_name` | `None` | Server IP address for gradio |
|
15 |
+
### 2. Point cloud render configs
|
16 |
+
|
17 |
+
| Configuration | Default Value | Explanation |
|
18 |
+
|:----------------- |:--------------- |:-------------------------------------------------------- |
|
19 |
+
| `--radius_scale` | `1.0` | Scale factor for the spherical radius |
|
20 |
+
| `--camera` | `traj` | Camera pose type, either 'traj' or 'target' |
|
21 |
+
| `--mode` | `gradual` | Mode of operation, 'gradual', 'bullet', or 'direct' |
|
22 |
+
| `--mask` | `False` | Clean the point cloud data if true |
|
23 |
+
| `--target_pose` | `None` | Required for 'target' camera pose type, specifies a relative camera pose sequece (theta, phi, r, x, y). +theta (theta<50) rotates camera upward, +phi (phi<50) rotates camera to right, +r (r<0.6) moves camera forward, +x (x<4) pans the camera to right, +y (y<4) pans the camera upward |
|
24 |
+
| `--traj_txt` | `None` | Required for 'traj' camera pose type, a txt file specifying a complex camera trajectory ([examples](../test/trajs/loop1.txt)). The fist line is the theta sequence, the second line the phi sequence, and the last line the r sequence |
|
25 |
+
| `--near` | `0.0001` | Near clipping plane distance |
|
26 |
+
| `--far` | `10000.0` | Far clipping plane distance |
|
27 |
+
| `--anchor_idx` | `0` | One GT frame for anchor frame |
|
extern/depthcrafter/__init__.py
ADDED
File without changes
|
extern/depthcrafter/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (151 Bytes). View file
|
|
extern/depthcrafter/__pycache__/demo.cpython-310.pyc
ADDED
Binary file (3.86 kB). View file
|
|
extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc
ADDED
Binary file (7.88 kB). View file
|
|
extern/depthcrafter/__pycache__/infer.cpython-310.pyc
ADDED
Binary file (2.31 kB). View file
|
|
extern/depthcrafter/__pycache__/unet.cpython-310.pyc
ADDED
Binary file (2.62 kB). View file
|
|
extern/depthcrafter/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.29 kB). View file
|
|
extern/depthcrafter/depth_crafter_ppl.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Dict, List, Optional, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
|
7 |
+
_resize_with_antialiasing,
|
8 |
+
StableVideoDiffusionPipelineOutput,
|
9 |
+
StableVideoDiffusionPipeline,
|
10 |
+
retrieve_timesteps,
|
11 |
+
)
|
12 |
+
from diffusers.utils import logging
|
13 |
+
from diffusers.utils.torch_utils import randn_tensor
|
14 |
+
|
15 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
16 |
+
|
17 |
+
|
18 |
+
class DepthCrafterPipeline(StableVideoDiffusionPipeline):
|
19 |
+
|
20 |
+
@torch.inference_mode()
|
21 |
+
def encode_video(
|
22 |
+
self,
|
23 |
+
video: torch.Tensor,
|
24 |
+
chunk_size: int = 14,
|
25 |
+
) -> torch.Tensor:
|
26 |
+
"""
|
27 |
+
:param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
|
28 |
+
:param chunk_size: the chunk size to encode video
|
29 |
+
:return: image_embeddings in shape of [b, 1024]
|
30 |
+
"""
|
31 |
+
|
32 |
+
video_224 = _resize_with_antialiasing(video.float(), (224, 224))
|
33 |
+
video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
|
34 |
+
|
35 |
+
embeddings = []
|
36 |
+
for i in range(0, video_224.shape[0], chunk_size):
|
37 |
+
tmp = self.feature_extractor(
|
38 |
+
images=video_224[i : i + chunk_size],
|
39 |
+
do_normalize=True,
|
40 |
+
do_center_crop=False,
|
41 |
+
do_resize=False,
|
42 |
+
do_rescale=False,
|
43 |
+
return_tensors="pt",
|
44 |
+
).pixel_values.to(video.device, dtype=video.dtype)
|
45 |
+
embeddings.append(self.image_encoder(tmp).image_embeds) # [b, 1024]
|
46 |
+
|
47 |
+
embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
|
48 |
+
return embeddings
|
49 |
+
|
50 |
+
@torch.inference_mode()
|
51 |
+
def encode_vae_video(
|
52 |
+
self,
|
53 |
+
video: torch.Tensor,
|
54 |
+
chunk_size: int = 14,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
:param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
|
58 |
+
:param chunk_size: the chunk size to encode video
|
59 |
+
:return: vae latents in shape of [b, c, h, w]
|
60 |
+
"""
|
61 |
+
video_latents = []
|
62 |
+
for i in range(0, video.shape[0], chunk_size):
|
63 |
+
video_latents.append(
|
64 |
+
self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
|
65 |
+
)
|
66 |
+
video_latents = torch.cat(video_latents, dim=0)
|
67 |
+
return video_latents
|
68 |
+
|
69 |
+
@staticmethod
|
70 |
+
def check_inputs(video, height, width):
|
71 |
+
"""
|
72 |
+
:param video:
|
73 |
+
:param height:
|
74 |
+
:param width:
|
75 |
+
:return:
|
76 |
+
"""
|
77 |
+
if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
|
78 |
+
raise ValueError(
|
79 |
+
f"Expected `video` to be a `torch.Tensor` or `VideoReader`, but got a {type(video)}"
|
80 |
+
)
|
81 |
+
|
82 |
+
if height % 8 != 0 or width % 8 != 0:
|
83 |
+
raise ValueError(
|
84 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
85 |
+
)
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def __call__(
|
89 |
+
self,
|
90 |
+
video: Union[np.ndarray, torch.Tensor],
|
91 |
+
height: int = 576,
|
92 |
+
width: int = 1024,
|
93 |
+
num_inference_steps: int = 25,
|
94 |
+
guidance_scale: float = 1.0,
|
95 |
+
window_size: Optional[int] = 110,
|
96 |
+
noise_aug_strength: float = 0.02,
|
97 |
+
decode_chunk_size: Optional[int] = None,
|
98 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
99 |
+
latents: Optional[torch.FloatTensor] = None,
|
100 |
+
output_type: Optional[str] = "pil",
|
101 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
102 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
103 |
+
return_dict: bool = True,
|
104 |
+
overlap: int = 25,
|
105 |
+
track_time: bool = False,
|
106 |
+
):
|
107 |
+
"""
|
108 |
+
:param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
|
109 |
+
:param height:
|
110 |
+
:param width:
|
111 |
+
:param num_inference_steps:
|
112 |
+
:param guidance_scale:
|
113 |
+
:param window_size: sliding window processing size
|
114 |
+
:param fps:
|
115 |
+
:param motion_bucket_id:
|
116 |
+
:param noise_aug_strength:
|
117 |
+
:param decode_chunk_size:
|
118 |
+
:param generator:
|
119 |
+
:param latents:
|
120 |
+
:param output_type:
|
121 |
+
:param callback_on_step_end:
|
122 |
+
:param callback_on_step_end_tensor_inputs:
|
123 |
+
:param return_dict:
|
124 |
+
:return:
|
125 |
+
"""
|
126 |
+
# 0. Default height and width to unet
|
127 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
128 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
129 |
+
num_frames = video.shape[0]
|
130 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
|
131 |
+
if num_frames <= window_size:
|
132 |
+
window_size = num_frames
|
133 |
+
overlap = 0
|
134 |
+
stride = window_size - overlap
|
135 |
+
|
136 |
+
# 1. Check inputs. Raise error if not correct
|
137 |
+
self.check_inputs(video, height, width)
|
138 |
+
|
139 |
+
# 2. Define call parameters
|
140 |
+
batch_size = 1
|
141 |
+
device = self._execution_device
|
142 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
143 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
144 |
+
# corresponds to doing no classifier free guidance.
|
145 |
+
self._guidance_scale = guidance_scale
|
146 |
+
|
147 |
+
# 3. Encode input video
|
148 |
+
if isinstance(video, np.ndarray):
|
149 |
+
video = torch.from_numpy(video.transpose(0, 3, 1, 2))
|
150 |
+
else:
|
151 |
+
assert isinstance(video, torch.Tensor)
|
152 |
+
video = video.to(device=device, dtype=self.dtype)
|
153 |
+
video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
|
154 |
+
|
155 |
+
if track_time:
|
156 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
157 |
+
encode_event = torch.cuda.Event(enable_timing=True)
|
158 |
+
denoise_event = torch.cuda.Event(enable_timing=True)
|
159 |
+
decode_event = torch.cuda.Event(enable_timing=True)
|
160 |
+
start_event.record()
|
161 |
+
|
162 |
+
video_embeddings = self.encode_video(
|
163 |
+
video, chunk_size=decode_chunk_size
|
164 |
+
).unsqueeze(
|
165 |
+
0
|
166 |
+
) # [1, t, 1024]
|
167 |
+
torch.cuda.empty_cache()
|
168 |
+
# 4. Encode input image using VAE
|
169 |
+
noise = randn_tensor(
|
170 |
+
video.shape, generator=generator, device=device, dtype=video.dtype
|
171 |
+
)
|
172 |
+
video = video + noise_aug_strength * noise # in [t, c, h, w]
|
173 |
+
|
174 |
+
# pdb.set_trace()
|
175 |
+
needs_upcasting = (
|
176 |
+
self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
177 |
+
)
|
178 |
+
if needs_upcasting:
|
179 |
+
self.vae.to(dtype=torch.float32)
|
180 |
+
|
181 |
+
video_latents = self.encode_vae_video(
|
182 |
+
video.to(self.vae.dtype),
|
183 |
+
chunk_size=decode_chunk_size,
|
184 |
+
).unsqueeze(
|
185 |
+
0
|
186 |
+
) # [1, t, c, h, w]
|
187 |
+
|
188 |
+
if track_time:
|
189 |
+
encode_event.record()
|
190 |
+
torch.cuda.synchronize()
|
191 |
+
elapsed_time_ms = start_event.elapsed_time(encode_event)
|
192 |
+
print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")
|
193 |
+
|
194 |
+
torch.cuda.empty_cache()
|
195 |
+
|
196 |
+
# cast back to fp16 if needed
|
197 |
+
if needs_upcasting:
|
198 |
+
self.vae.to(dtype=torch.float16)
|
199 |
+
|
200 |
+
# 5. Get Added Time IDs
|
201 |
+
added_time_ids = self._get_add_time_ids(
|
202 |
+
7,
|
203 |
+
127,
|
204 |
+
noise_aug_strength,
|
205 |
+
video_embeddings.dtype,
|
206 |
+
batch_size,
|
207 |
+
1,
|
208 |
+
False,
|
209 |
+
) # [1 or 2, 3]
|
210 |
+
added_time_ids = added_time_ids.to(device)
|
211 |
+
|
212 |
+
# 6. Prepare timesteps
|
213 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
214 |
+
self.scheduler, num_inference_steps, device, None, None
|
215 |
+
)
|
216 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
217 |
+
self._num_timesteps = len(timesteps)
|
218 |
+
|
219 |
+
# 7. Prepare latent variables
|
220 |
+
num_channels_latents = self.unet.config.in_channels
|
221 |
+
latents_init = self.prepare_latents(
|
222 |
+
batch_size,
|
223 |
+
window_size,
|
224 |
+
num_channels_latents,
|
225 |
+
height,
|
226 |
+
width,
|
227 |
+
video_embeddings.dtype,
|
228 |
+
device,
|
229 |
+
generator,
|
230 |
+
latents,
|
231 |
+
) # [1, t, c, h, w]
|
232 |
+
latents_all = None
|
233 |
+
|
234 |
+
idx_start = 0
|
235 |
+
if overlap > 0:
|
236 |
+
weights = torch.linspace(0, 1, overlap, device=device)
|
237 |
+
weights = weights.view(1, overlap, 1, 1, 1)
|
238 |
+
else:
|
239 |
+
weights = None
|
240 |
+
|
241 |
+
torch.cuda.empty_cache()
|
242 |
+
|
243 |
+
# inference strategy for long videos
|
244 |
+
# two main strategies: 1. noise init from previous frame, 2. segments stitching
|
245 |
+
while idx_start < num_frames - overlap:
|
246 |
+
idx_end = min(idx_start + window_size, num_frames)
|
247 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
248 |
+
|
249 |
+
# 9. Denoising loop
|
250 |
+
latents = latents_init[:, : idx_end - idx_start].clone()
|
251 |
+
latents_init = torch.cat(
|
252 |
+
[latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
|
253 |
+
)
|
254 |
+
|
255 |
+
video_latents_current = video_latents[:, idx_start:idx_end]
|
256 |
+
video_embeddings_current = video_embeddings[:, idx_start:idx_end]
|
257 |
+
|
258 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
259 |
+
for i, t in enumerate(timesteps):
|
260 |
+
if latents_all is not None and i == 0:
|
261 |
+
latents[:, :overlap] = (
|
262 |
+
latents_all[:, -overlap:]
|
263 |
+
+ latents[:, :overlap]
|
264 |
+
/ self.scheduler.init_noise_sigma
|
265 |
+
* self.scheduler.sigmas[i]
|
266 |
+
)
|
267 |
+
|
268 |
+
latent_model_input = latents # [1, t, c, h, w]
|
269 |
+
latent_model_input = self.scheduler.scale_model_input(
|
270 |
+
latent_model_input, t
|
271 |
+
) # [1, t, c, h, w]
|
272 |
+
latent_model_input = torch.cat(
|
273 |
+
[latent_model_input, video_latents_current], dim=2
|
274 |
+
)
|
275 |
+
noise_pred = self.unet(
|
276 |
+
latent_model_input,
|
277 |
+
t,
|
278 |
+
encoder_hidden_states=video_embeddings_current,
|
279 |
+
added_time_ids=added_time_ids,
|
280 |
+
return_dict=False,
|
281 |
+
)[0]
|
282 |
+
# perform guidance
|
283 |
+
if self.do_classifier_free_guidance:
|
284 |
+
latent_model_input = latents
|
285 |
+
latent_model_input = self.scheduler.scale_model_input(
|
286 |
+
latent_model_input, t
|
287 |
+
)
|
288 |
+
latent_model_input = torch.cat(
|
289 |
+
[latent_model_input, torch.zeros_like(latent_model_input)],
|
290 |
+
dim=2,
|
291 |
+
)
|
292 |
+
noise_pred_uncond = self.unet(
|
293 |
+
latent_model_input,
|
294 |
+
t,
|
295 |
+
encoder_hidden_states=torch.zeros_like(
|
296 |
+
video_embeddings_current
|
297 |
+
),
|
298 |
+
added_time_ids=added_time_ids,
|
299 |
+
return_dict=False,
|
300 |
+
)[0]
|
301 |
+
|
302 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
303 |
+
noise_pred - noise_pred_uncond
|
304 |
+
)
|
305 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
306 |
+
|
307 |
+
if callback_on_step_end is not None:
|
308 |
+
callback_kwargs = {}
|
309 |
+
for k in callback_on_step_end_tensor_inputs:
|
310 |
+
callback_kwargs[k] = locals()[k]
|
311 |
+
callback_outputs = callback_on_step_end(
|
312 |
+
self, i, t, callback_kwargs
|
313 |
+
)
|
314 |
+
|
315 |
+
latents = callback_outputs.pop("latents", latents)
|
316 |
+
|
317 |
+
if i == len(timesteps) - 1 or (
|
318 |
+
(i + 1) > num_warmup_steps
|
319 |
+
and (i + 1) % self.scheduler.order == 0
|
320 |
+
):
|
321 |
+
progress_bar.update()
|
322 |
+
|
323 |
+
if latents_all is None:
|
324 |
+
latents_all = latents.clone()
|
325 |
+
else:
|
326 |
+
assert weights is not None
|
327 |
+
# latents_all[:, -overlap:] = (
|
328 |
+
# latents[:, :overlap] + latents_all[:, -overlap:]
|
329 |
+
# ) / 2.0
|
330 |
+
latents_all[:, -overlap:] = latents[
|
331 |
+
:, :overlap
|
332 |
+
] * weights + latents_all[:, -overlap:] * (1 - weights)
|
333 |
+
latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
|
334 |
+
|
335 |
+
idx_start += stride
|
336 |
+
|
337 |
+
if track_time:
|
338 |
+
denoise_event.record()
|
339 |
+
torch.cuda.synchronize()
|
340 |
+
elapsed_time_ms = encode_event.elapsed_time(denoise_event)
|
341 |
+
print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")
|
342 |
+
|
343 |
+
if not output_type == "latent":
|
344 |
+
# cast back to fp16 if needed
|
345 |
+
if needs_upcasting:
|
346 |
+
self.vae.to(dtype=torch.float16)
|
347 |
+
frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)
|
348 |
+
|
349 |
+
if track_time:
|
350 |
+
decode_event.record()
|
351 |
+
torch.cuda.synchronize()
|
352 |
+
elapsed_time_ms = denoise_event.elapsed_time(decode_event)
|
353 |
+
print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")
|
354 |
+
|
355 |
+
frames = self.video_processor.postprocess_video(
|
356 |
+
video=frames, output_type=output_type
|
357 |
+
)
|
358 |
+
else:
|
359 |
+
frames = latents_all
|
360 |
+
|
361 |
+
self.maybe_free_model_hooks()
|
362 |
+
|
363 |
+
if not return_dict:
|
364 |
+
return frames
|
365 |
+
|
366 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
extern/depthcrafter/infer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from diffusers.training_utils import set_seed
|
7 |
+
from extern.depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
|
8 |
+
from extern.depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
|
9 |
+
|
10 |
+
|
11 |
+
class DepthCrafterDemo:
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
unet_path: str,
|
15 |
+
pre_train_path: str,
|
16 |
+
cpu_offload: str = "model",
|
17 |
+
device: str = "cuda:0"
|
18 |
+
):
|
19 |
+
unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
|
20 |
+
unet_path,
|
21 |
+
low_cpu_mem_usage=True,
|
22 |
+
torch_dtype=torch.float16,
|
23 |
+
)
|
24 |
+
# load weights of other components from the provided checkpoint
|
25 |
+
self.pipe = DepthCrafterPipeline.from_pretrained(
|
26 |
+
pre_train_path,
|
27 |
+
unet=unet,
|
28 |
+
torch_dtype=torch.float16,
|
29 |
+
variant="fp16",
|
30 |
+
)
|
31 |
+
|
32 |
+
# for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
|
33 |
+
if cpu_offload is not None:
|
34 |
+
if cpu_offload == "sequential":
|
35 |
+
# This will slow, but save more memory
|
36 |
+
self.pipe.enable_sequential_cpu_offload()
|
37 |
+
elif cpu_offload == "model":
|
38 |
+
self.pipe.enable_model_cpu_offload()
|
39 |
+
else:
|
40 |
+
raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
|
41 |
+
else:
|
42 |
+
self.pipe.to(device)
|
43 |
+
# enable attention slicing and xformers memory efficient attention
|
44 |
+
try:
|
45 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
46 |
+
except Exception as e:
|
47 |
+
print(e)
|
48 |
+
print("Xformers is not enabled")
|
49 |
+
self.pipe.enable_attention_slicing()
|
50 |
+
|
51 |
+
def infer(
|
52 |
+
self,
|
53 |
+
frames,
|
54 |
+
near,
|
55 |
+
far,
|
56 |
+
num_denoising_steps: int,
|
57 |
+
guidance_scale: float,
|
58 |
+
window_size: int = 110,
|
59 |
+
overlap: int = 25,
|
60 |
+
seed: int = 42,
|
61 |
+
track_time: bool = True,
|
62 |
+
):
|
63 |
+
set_seed(seed)
|
64 |
+
|
65 |
+
# inference the depth map using the DepthCrafter pipeline
|
66 |
+
with torch.inference_mode():
|
67 |
+
res = self.pipe(
|
68 |
+
frames,
|
69 |
+
height=frames.shape[1],
|
70 |
+
width=frames.shape[2],
|
71 |
+
output_type="np",
|
72 |
+
guidance_scale=guidance_scale,
|
73 |
+
num_inference_steps=num_denoising_steps,
|
74 |
+
window_size=window_size,
|
75 |
+
overlap=overlap,
|
76 |
+
track_time=track_time,
|
77 |
+
).frames[0]
|
78 |
+
# convert the three-channel output to a single channel depth map
|
79 |
+
res = res.sum(-1) / res.shape[-1]
|
80 |
+
# normalize the depth map to [0, 1] across the whole video
|
81 |
+
depths = (res - res.min()) / (res.max() - res.min())
|
82 |
+
# visualize the depth map and save the results
|
83 |
+
# vis = vis_sequence_depth(res)
|
84 |
+
# save the depth map and visualization with the target FPS
|
85 |
+
depths = torch.from_numpy(depths).unsqueeze(1) # 49 576 1024 ->
|
86 |
+
depths *= 3900 # compatible with da output
|
87 |
+
depths[depths < 1e-5] = 1e-5
|
88 |
+
depths = 10000. / depths
|
89 |
+
depths = depths.clip(near, far)
|
90 |
+
|
91 |
+
return depths
|
extern/depthcrafter/unet.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import UNetSpatioTemporalConditionModel
|
5 |
+
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
|
6 |
+
|
7 |
+
|
8 |
+
class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
|
9 |
+
UNetSpatioTemporalConditionModel
|
10 |
+
):
|
11 |
+
|
12 |
+
def forward(
|
13 |
+
self,
|
14 |
+
sample: torch.Tensor,
|
15 |
+
timestep: Union[torch.Tensor, float, int],
|
16 |
+
encoder_hidden_states: torch.Tensor,
|
17 |
+
added_time_ids: torch.Tensor,
|
18 |
+
return_dict: bool = True,
|
19 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
20 |
+
|
21 |
+
# 1. time
|
22 |
+
timesteps = timestep
|
23 |
+
if not torch.is_tensor(timesteps):
|
24 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
25 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
26 |
+
is_mps = sample.device.type == "mps"
|
27 |
+
if isinstance(timestep, float):
|
28 |
+
dtype = torch.float32 if is_mps else torch.float64
|
29 |
+
else:
|
30 |
+
dtype = torch.int32 if is_mps else torch.int64
|
31 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
32 |
+
elif len(timesteps.shape) == 0:
|
33 |
+
timesteps = timesteps[None].to(sample.device)
|
34 |
+
|
35 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
36 |
+
batch_size, num_frames = sample.shape[:2]
|
37 |
+
timesteps = timesteps.expand(batch_size)
|
38 |
+
|
39 |
+
t_emb = self.time_proj(timesteps)
|
40 |
+
|
41 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
42 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
43 |
+
# there might be better ways to encapsulate this.
|
44 |
+
t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
|
45 |
+
|
46 |
+
emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
|
47 |
+
|
48 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
49 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
50 |
+
time_embeds = time_embeds.to(emb.dtype)
|
51 |
+
aug_emb = self.add_embedding(time_embeds)
|
52 |
+
emb = emb + aug_emb
|
53 |
+
|
54 |
+
# Flatten the batch and frames dimensions
|
55 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
56 |
+
sample = sample.flatten(0, 1)
|
57 |
+
# Repeat the embeddings num_video_frames times
|
58 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
59 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
60 |
+
# encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
|
61 |
+
encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
|
62 |
+
|
63 |
+
# 2. pre-process
|
64 |
+
sample = sample.to(dtype=self.conv_in.weight.dtype)
|
65 |
+
assert sample.dtype == self.conv_in.weight.dtype, (
|
66 |
+
f"sample.dtype: {sample.dtype}, "
|
67 |
+
f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
|
68 |
+
)
|
69 |
+
sample = self.conv_in(sample)
|
70 |
+
|
71 |
+
image_only_indicator = torch.zeros(
|
72 |
+
batch_size, num_frames, dtype=sample.dtype, device=sample.device
|
73 |
+
)
|
74 |
+
|
75 |
+
down_block_res_samples = (sample,)
|
76 |
+
for downsample_block in self.down_blocks:
|
77 |
+
if (
|
78 |
+
hasattr(downsample_block, "has_cross_attention")
|
79 |
+
and downsample_block.has_cross_attention
|
80 |
+
):
|
81 |
+
sample, res_samples = downsample_block(
|
82 |
+
hidden_states=sample,
|
83 |
+
temb=emb,
|
84 |
+
encoder_hidden_states=encoder_hidden_states,
|
85 |
+
image_only_indicator=image_only_indicator,
|
86 |
+
)
|
87 |
+
|
88 |
+
else:
|
89 |
+
sample, res_samples = downsample_block(
|
90 |
+
hidden_states=sample,
|
91 |
+
temb=emb,
|
92 |
+
image_only_indicator=image_only_indicator,
|
93 |
+
)
|
94 |
+
|
95 |
+
down_block_res_samples += res_samples
|
96 |
+
|
97 |
+
# 4. mid
|
98 |
+
sample = self.mid_block(
|
99 |
+
hidden_states=sample,
|
100 |
+
temb=emb,
|
101 |
+
encoder_hidden_states=encoder_hidden_states,
|
102 |
+
image_only_indicator=image_only_indicator,
|
103 |
+
)
|
104 |
+
|
105 |
+
# 5. up
|
106 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
107 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
108 |
+
down_block_res_samples = down_block_res_samples[
|
109 |
+
: -len(upsample_block.resnets)
|
110 |
+
]
|
111 |
+
|
112 |
+
if (
|
113 |
+
hasattr(upsample_block, "has_cross_attention")
|
114 |
+
and upsample_block.has_cross_attention
|
115 |
+
):
|
116 |
+
sample = upsample_block(
|
117 |
+
hidden_states=sample,
|
118 |
+
res_hidden_states_tuple=res_samples,
|
119 |
+
temb=emb,
|
120 |
+
encoder_hidden_states=encoder_hidden_states,
|
121 |
+
image_only_indicator=image_only_indicator,
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
sample = upsample_block(
|
125 |
+
hidden_states=sample,
|
126 |
+
res_hidden_states_tuple=res_samples,
|
127 |
+
temb=emb,
|
128 |
+
image_only_indicator=image_only_indicator,
|
129 |
+
)
|
130 |
+
|
131 |
+
# 6. post-process
|
132 |
+
sample = self.conv_norm_out(sample)
|
133 |
+
sample = self.conv_act(sample)
|
134 |
+
sample = self.conv_out(sample)
|
135 |
+
|
136 |
+
# 7. Reshape back to original shape
|
137 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
138 |
+
|
139 |
+
if not return_dict:
|
140 |
+
return (sample,)
|
141 |
+
|
142 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc
ADDED
Binary file (12.2 kB). View file
|
|
extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc
ADDED
Binary file (3.64 kB). View file
|
|
extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc
ADDED
Binary file (2.76 kB). View file
|
|
extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc
ADDED
Binary file (1.52 kB). View file
|
|
extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc
ADDED
Binary file (4.66 kB). View file
|
|
extern/video_depth_anything/dinov2.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
27 |
+
if not depth_first and include_root:
|
28 |
+
fn(module=module, name=name)
|
29 |
+
for child_name, child_module in module.named_children():
|
30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
|
38 |
+
def forward(self, x):
|
39 |
+
for b in self:
|
40 |
+
x = b(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
img_size=224,
|
48 |
+
patch_size=16,
|
49 |
+
in_chans=3,
|
50 |
+
embed_dim=768,
|
51 |
+
depth=12,
|
52 |
+
num_heads=12,
|
53 |
+
mlp_ratio=4.0,
|
54 |
+
qkv_bias=True,
|
55 |
+
ffn_bias=True,
|
56 |
+
proj_bias=True,
|
57 |
+
drop_path_rate=0.0,
|
58 |
+
drop_path_uniform=False,
|
59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
60 |
+
embed_layer=PatchEmbed,
|
61 |
+
act_layer=nn.GELU,
|
62 |
+
block_fn=Block,
|
63 |
+
ffn_layer="mlp",
|
64 |
+
block_chunks=1,
|
65 |
+
num_register_tokens=0,
|
66 |
+
interpolate_antialias=False,
|
67 |
+
interpolate_offset=0.1,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
img_size (int, tuple): input image size
|
72 |
+
patch_size (int, tuple): patch size
|
73 |
+
in_chans (int): number of input channels
|
74 |
+
embed_dim (int): embedding dimension
|
75 |
+
depth (int): depth of transformer
|
76 |
+
num_heads (int): number of attention heads
|
77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
78 |
+
qkv_bias (bool): enable bias for qkv if True
|
79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
80 |
+
ffn_bias (bool): enable bias for ffn if True
|
81 |
+
drop_path_rate (float): stochastic depth rate
|
82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
83 |
+
weight_init (str): weight init scheme
|
84 |
+
init_values (float): layer-scale init values
|
85 |
+
embed_layer (nn.Module): patch embedding layer
|
86 |
+
act_layer (nn.Module): MLP activation layer
|
87 |
+
block_fn (nn.Module): transformer block class
|
88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
96 |
+
|
97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
98 |
+
self.num_tokens = 1
|
99 |
+
self.n_blocks = depth
|
100 |
+
self.num_heads = num_heads
|
101 |
+
self.patch_size = patch_size
|
102 |
+
self.num_register_tokens = num_register_tokens
|
103 |
+
self.interpolate_antialias = interpolate_antialias
|
104 |
+
self.interpolate_offset = interpolate_offset
|
105 |
+
|
106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
107 |
+
num_patches = self.patch_embed.num_patches
|
108 |
+
|
109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
111 |
+
assert num_register_tokens >= 0
|
112 |
+
self.register_tokens = (
|
113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
114 |
+
)
|
115 |
+
|
116 |
+
if drop_path_uniform is True:
|
117 |
+
dpr = [drop_path_rate] * depth
|
118 |
+
else:
|
119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
120 |
+
|
121 |
+
if ffn_layer == "mlp":
|
122 |
+
logger.info("using MLP layer as FFN")
|
123 |
+
ffn_layer = Mlp
|
124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
125 |
+
logger.info("using SwiGLU layer as FFN")
|
126 |
+
ffn_layer = SwiGLUFFNFused
|
127 |
+
elif ffn_layer == "identity":
|
128 |
+
logger.info("using Identity layer as FFN")
|
129 |
+
|
130 |
+
def f(*args, **kwargs):
|
131 |
+
return nn.Identity()
|
132 |
+
|
133 |
+
ffn_layer = f
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
+
|
137 |
+
blocks_list = [
|
138 |
+
block_fn(
|
139 |
+
dim=embed_dim,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
qkv_bias=qkv_bias,
|
143 |
+
proj_bias=proj_bias,
|
144 |
+
ffn_bias=ffn_bias,
|
145 |
+
drop_path=dpr[i],
|
146 |
+
norm_layer=norm_layer,
|
147 |
+
act_layer=act_layer,
|
148 |
+
ffn_layer=ffn_layer,
|
149 |
+
init_values=init_values,
|
150 |
+
)
|
151 |
+
for i in range(depth)
|
152 |
+
]
|
153 |
+
if block_chunks > 0:
|
154 |
+
self.chunked_blocks = True
|
155 |
+
chunked_blocks = []
|
156 |
+
chunksize = depth // block_chunks
|
157 |
+
for i in range(0, depth, chunksize):
|
158 |
+
# this is to keep the block index consistent if we chunk the block list
|
159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
161 |
+
else:
|
162 |
+
self.chunked_blocks = False
|
163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
164 |
+
|
165 |
+
self.norm = norm_layer(embed_dim)
|
166 |
+
self.head = nn.Identity()
|
167 |
+
|
168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
169 |
+
|
170 |
+
self.init_weights()
|
171 |
+
|
172 |
+
def init_weights(self):
|
173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
175 |
+
if self.register_tokens is not None:
|
176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
177 |
+
named_apply(init_weights_vit_timm, self)
|
178 |
+
|
179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
180 |
+
previous_dtype = x.dtype
|
181 |
+
npatch = x.shape[1] - 1
|
182 |
+
N = self.pos_embed.shape[1] - 1
|
183 |
+
if npatch == N and w == h:
|
184 |
+
return self.pos_embed
|
185 |
+
pos_embed = self.pos_embed.float()
|
186 |
+
class_pos_embed = pos_embed[:, 0]
|
187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
188 |
+
dim = x.shape[-1]
|
189 |
+
w0 = w // self.patch_size
|
190 |
+
h0 = h // self.patch_size
|
191 |
+
# we add a small number to avoid floating point error in the interpolation
|
192 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
193 |
+
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
|
194 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
195 |
+
# w0, h0 = w0 + 0.1, h0 + 0.1
|
196 |
+
|
197 |
+
sqrt_N = math.sqrt(N)
|
198 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
199 |
+
patch_pos_embed = nn.functional.interpolate(
|
200 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
201 |
+
scale_factor=(sx, sy),
|
202 |
+
# (int(w0), int(h0)), # to solve the upsampling shape issue
|
203 |
+
mode="bicubic",
|
204 |
+
antialias=self.interpolate_antialias
|
205 |
+
)
|
206 |
+
|
207 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
208 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
209 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
210 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
211 |
+
|
212 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
213 |
+
B, nc, w, h = x.shape
|
214 |
+
x = self.patch_embed(x)
|
215 |
+
if masks is not None:
|
216 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
217 |
+
|
218 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
219 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
220 |
+
|
221 |
+
if self.register_tokens is not None:
|
222 |
+
x = torch.cat(
|
223 |
+
(
|
224 |
+
x[:, :1],
|
225 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
226 |
+
x[:, 1:],
|
227 |
+
),
|
228 |
+
dim=1,
|
229 |
+
)
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
def forward_features_list(self, x_list, masks_list):
|
234 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
235 |
+
for blk in self.blocks:
|
236 |
+
x = blk(x)
|
237 |
+
|
238 |
+
all_x = x
|
239 |
+
output = []
|
240 |
+
for x, masks in zip(all_x, masks_list):
|
241 |
+
x_norm = self.norm(x)
|
242 |
+
output.append(
|
243 |
+
{
|
244 |
+
"x_norm_clstoken": x_norm[:, 0],
|
245 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
246 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
247 |
+
"x_prenorm": x,
|
248 |
+
"masks": masks,
|
249 |
+
}
|
250 |
+
)
|
251 |
+
return output
|
252 |
+
|
253 |
+
def forward_features(self, x, masks=None):
|
254 |
+
if isinstance(x, list):
|
255 |
+
return self.forward_features_list(x, masks)
|
256 |
+
|
257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
x = blk(x)
|
261 |
+
|
262 |
+
x_norm = self.norm(x)
|
263 |
+
return {
|
264 |
+
"x_norm_clstoken": x_norm[:, 0],
|
265 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
266 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
267 |
+
"x_prenorm": x,
|
268 |
+
"masks": masks,
|
269 |
+
}
|
270 |
+
|
271 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
272 |
+
x = self.prepare_tokens_with_masks(x)
|
273 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
274 |
+
output, total_block_len = [], len(self.blocks)
|
275 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
276 |
+
for i, blk in enumerate(self.blocks):
|
277 |
+
x = blk(x)
|
278 |
+
if i in blocks_to_take:
|
279 |
+
output.append(x)
|
280 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
281 |
+
return output
|
282 |
+
|
283 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
284 |
+
x = self.prepare_tokens_with_masks(x)
|
285 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
286 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
287 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
288 |
+
for block_chunk in self.blocks:
|
289 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
290 |
+
x = blk(x)
|
291 |
+
if i in blocks_to_take:
|
292 |
+
output.append(x)
|
293 |
+
i += 1
|
294 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
295 |
+
return output
|
296 |
+
|
297 |
+
def get_intermediate_layers(
|
298 |
+
self,
|
299 |
+
x: torch.Tensor,
|
300 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
301 |
+
reshape: bool = False,
|
302 |
+
return_class_token: bool = False,
|
303 |
+
norm=True
|
304 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
305 |
+
if self.chunked_blocks:
|
306 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
307 |
+
else:
|
308 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
309 |
+
if norm:
|
310 |
+
outputs = [self.norm(out) for out in outputs]
|
311 |
+
class_tokens = [out[:, 0] for out in outputs]
|
312 |
+
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
313 |
+
if reshape:
|
314 |
+
B, _, w, h = x.shape
|
315 |
+
outputs = [
|
316 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
317 |
+
for out in outputs
|
318 |
+
]
|
319 |
+
if return_class_token:
|
320 |
+
return tuple(zip(outputs, class_tokens))
|
321 |
+
return tuple(outputs)
|
322 |
+
|
323 |
+
def forward(self, *args, is_training=False, **kwargs):
|
324 |
+
ret = self.forward_features(*args, **kwargs)
|
325 |
+
if is_training:
|
326 |
+
return ret
|
327 |
+
else:
|
328 |
+
return self.head(ret["x_norm_clstoken"])
|
329 |
+
|
330 |
+
|
331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
332 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
333 |
+
if isinstance(module, nn.Linear):
|
334 |
+
trunc_normal_(module.weight, std=0.02)
|
335 |
+
if module.bias is not None:
|
336 |
+
nn.init.zeros_(module.bias)
|
337 |
+
|
338 |
+
|
339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
340 |
+
model = DinoVisionTransformer(
|
341 |
+
patch_size=patch_size,
|
342 |
+
embed_dim=384,
|
343 |
+
depth=12,
|
344 |
+
num_heads=6,
|
345 |
+
mlp_ratio=4,
|
346 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
347 |
+
num_register_tokens=num_register_tokens,
|
348 |
+
**kwargs,
|
349 |
+
)
|
350 |
+
return model
|
351 |
+
|
352 |
+
|
353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
354 |
+
model = DinoVisionTransformer(
|
355 |
+
patch_size=patch_size,
|
356 |
+
embed_dim=768,
|
357 |
+
depth=12,
|
358 |
+
num_heads=12,
|
359 |
+
mlp_ratio=4,
|
360 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
361 |
+
num_register_tokens=num_register_tokens,
|
362 |
+
**kwargs,
|
363 |
+
)
|
364 |
+
return model
|
365 |
+
|
366 |
+
|
367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
368 |
+
model = DinoVisionTransformer(
|
369 |
+
patch_size=patch_size,
|
370 |
+
embed_dim=1024,
|
371 |
+
depth=24,
|
372 |
+
num_heads=16,
|
373 |
+
mlp_ratio=4,
|
374 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
375 |
+
num_register_tokens=num_register_tokens,
|
376 |
+
**kwargs,
|
377 |
+
)
|
378 |
+
return model
|
379 |
+
|
380 |
+
|
381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
382 |
+
"""
|
383 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
384 |
+
"""
|
385 |
+
model = DinoVisionTransformer(
|
386 |
+
patch_size=patch_size,
|
387 |
+
embed_dim=1536,
|
388 |
+
depth=40,
|
389 |
+
num_heads=24,
|
390 |
+
mlp_ratio=4,
|
391 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
392 |
+
num_register_tokens=num_register_tokens,
|
393 |
+
**kwargs,
|
394 |
+
)
|
395 |
+
return model
|
396 |
+
|
397 |
+
|
398 |
+
def DINOv2(model_name):
|
399 |
+
model_zoo = {
|
400 |
+
"vits": vit_small,
|
401 |
+
"vitb": vit_base,
|
402 |
+
"vitl": vit_large,
|
403 |
+
"vitg": vit_giant2
|
404 |
+
}
|
405 |
+
|
406 |
+
return model_zoo[model_name](
|
407 |
+
img_size=518,
|
408 |
+
patch_size=14,
|
409 |
+
init_values=1.0,
|
410 |
+
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
|
411 |
+
block_chunks=0,
|
412 |
+
num_register_tokens=0,
|
413 |
+
interpolate_antialias=False,
|
414 |
+
interpolate_offset=0.1
|
415 |
+
)
|
extern/video_depth_anything/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (415 Bytes). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (2.38 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc
ADDED
Binary file (7.99 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc
ADDED
Binary file (1.22 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc
ADDED
Binary file (1.02 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc
ADDED
Binary file (1.21 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc
ADDED
Binary file (2.66 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc
ADDED
Binary file (2.01 kB). View file
|
|
extern/video_depth_anything/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
try:
|
21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
22 |
+
|
23 |
+
XFORMERS_AVAILABLE = True
|
24 |
+
except ImportError:
|
25 |
+
logger.warning("xFormers not available")
|
26 |
+
XFORMERS_AVAILABLE = False
|
27 |
+
|
28 |
+
|
29 |
+
class Attention(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
dim: int,
|
33 |
+
num_heads: int = 8,
|
34 |
+
qkv_bias: bool = False,
|
35 |
+
proj_bias: bool = True,
|
36 |
+
attn_drop: float = 0.0,
|
37 |
+
proj_drop: float = 0.0,
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
self.num_heads = num_heads
|
41 |
+
head_dim = dim // num_heads
|
42 |
+
self.scale = head_dim**-0.5
|
43 |
+
|
44 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
48 |
+
|
49 |
+
def forward(self, x: Tensor) -> Tensor:
|
50 |
+
B, N, C = x.shape
|
51 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
52 |
+
|
53 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
54 |
+
attn = q @ k.transpose(-2, -1)
|
55 |
+
|
56 |
+
attn = attn.softmax(dim=-1)
|
57 |
+
attn = self.attn_drop(attn)
|
58 |
+
|
59 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
60 |
+
x = self.proj(x)
|
61 |
+
x = self.proj_drop(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class MemEffAttention(Attention):
|
66 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
67 |
+
if not XFORMERS_AVAILABLE:
|
68 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
69 |
+
return super().forward(x)
|
70 |
+
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
73 |
+
|
74 |
+
q, k, v = unbind(qkv, 2)
|
75 |
+
|
76 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
77 |
+
x = x.reshape([B, N, C])
|
78 |
+
|
79 |
+
x = self.proj(x)
|
80 |
+
x = self.proj_drop(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
extern/video_depth_anything/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from xformers.ops import fmha
|
28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
29 |
+
|
30 |
+
XFORMERS_AVAILABLE = True
|
31 |
+
except ImportError:
|
32 |
+
logger.warning("xFormers not available")
|
33 |
+
XFORMERS_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
class Block(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int,
|
41 |
+
mlp_ratio: float = 4.0,
|
42 |
+
qkv_bias: bool = False,
|
43 |
+
proj_bias: bool = True,
|
44 |
+
ffn_bias: bool = True,
|
45 |
+
drop: float = 0.0,
|
46 |
+
attn_drop: float = 0.0,
|
47 |
+
init_values=None,
|
48 |
+
drop_path: float = 0.0,
|
49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
56 |
+
self.norm1 = norm_layer(dim)
|
57 |
+
self.attn = attn_class(
|
58 |
+
dim,
|
59 |
+
num_heads=num_heads,
|
60 |
+
qkv_bias=qkv_bias,
|
61 |
+
proj_bias=proj_bias,
|
62 |
+
attn_drop=attn_drop,
|
63 |
+
proj_drop=drop,
|
64 |
+
)
|
65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
+
|
68 |
+
self.norm2 = norm_layer(dim)
|
69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
70 |
+
self.mlp = ffn_layer(
|
71 |
+
in_features=dim,
|
72 |
+
hidden_features=mlp_hidden_dim,
|
73 |
+
act_layer=act_layer,
|
74 |
+
drop=drop,
|
75 |
+
bias=ffn_bias,
|
76 |
+
)
|
77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
78 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
79 |
+
|
80 |
+
self.sample_drop_ratio = drop_path
|
81 |
+
|
82 |
+
def forward(self, x: Tensor) -> Tensor:
|
83 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
84 |
+
return self.ls1(self.attn(self.norm1(x)))
|
85 |
+
|
86 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
87 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
88 |
+
|
89 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
90 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
91 |
+
x = drop_add_residual_stochastic_depth(
|
92 |
+
x,
|
93 |
+
residual_func=attn_residual_func,
|
94 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
95 |
+
)
|
96 |
+
x = drop_add_residual_stochastic_depth(
|
97 |
+
x,
|
98 |
+
residual_func=ffn_residual_func,
|
99 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
100 |
+
)
|
101 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
102 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
103 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
104 |
+
else:
|
105 |
+
x = x + attn_residual_func(x)
|
106 |
+
x = x + ffn_residual_func(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def drop_add_residual_stochastic_depth(
|
111 |
+
x: Tensor,
|
112 |
+
residual_func: Callable[[Tensor], Tensor],
|
113 |
+
sample_drop_ratio: float = 0.0,
|
114 |
+
) -> Tensor:
|
115 |
+
# 1) extract subset using permutation
|
116 |
+
b, n, d = x.shape
|
117 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
118 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
119 |
+
x_subset = x[brange]
|
120 |
+
|
121 |
+
# 2) apply residual_func to get residual
|
122 |
+
residual = residual_func(x_subset)
|
123 |
+
|
124 |
+
x_flat = x.flatten(1)
|
125 |
+
residual = residual.flatten(1)
|
126 |
+
|
127 |
+
residual_scale_factor = b / sample_subset_size
|
128 |
+
|
129 |
+
# 3) add the residual
|
130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
131 |
+
return x_plus_residual.view_as(x)
|
132 |
+
|
133 |
+
|
134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
135 |
+
b, n, d = x.shape
|
136 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
137 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
138 |
+
residual_scale_factor = b / sample_subset_size
|
139 |
+
return brange, residual_scale_factor
|
140 |
+
|
141 |
+
|
142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
143 |
+
if scaling_vector is None:
|
144 |
+
x_flat = x.flatten(1)
|
145 |
+
residual = residual.flatten(1)
|
146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
147 |
+
else:
|
148 |
+
x_plus_residual = scaled_index_add(
|
149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
150 |
+
)
|
151 |
+
return x_plus_residual
|
152 |
+
|
153 |
+
|
154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
155 |
+
|
156 |
+
|
157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
158 |
+
"""
|
159 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
160 |
+
"""
|
161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
162 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
163 |
+
if all_shapes not in attn_bias_cache.keys():
|
164 |
+
seqlens = []
|
165 |
+
for b, x in zip(batch_sizes, x_list):
|
166 |
+
for _ in range(b):
|
167 |
+
seqlens.append(x.shape[1])
|
168 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
169 |
+
attn_bias._batch_sizes = batch_sizes
|
170 |
+
attn_bias_cache[all_shapes] = attn_bias
|
171 |
+
|
172 |
+
if branges is not None:
|
173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
174 |
+
else:
|
175 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
176 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
177 |
+
|
178 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
179 |
+
|
180 |
+
|
181 |
+
def drop_add_residual_stochastic_depth_list(
|
182 |
+
x_list: List[Tensor],
|
183 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
184 |
+
sample_drop_ratio: float = 0.0,
|
185 |
+
scaling_vector=None,
|
186 |
+
) -> Tensor:
|
187 |
+
# 1) generate random set of indices for dropping samples in the batch
|
188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
189 |
+
branges = [s[0] for s in branges_scales]
|
190 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
191 |
+
|
192 |
+
# 2) get attention bias and index+concat the tensors
|
193 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
194 |
+
|
195 |
+
# 3) apply residual_func to get residual, and split the result
|
196 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
197 |
+
|
198 |
+
outputs = []
|
199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class NestedTensorBlock(Block):
|
205 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
206 |
+
"""
|
207 |
+
x_list contains a list of tensors to nest together and run
|
208 |
+
"""
|
209 |
+
assert isinstance(self.attn, MemEffAttention)
|
210 |
+
|
211 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
212 |
+
|
213 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
214 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
215 |
+
|
216 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
217 |
+
return self.mlp(self.norm2(x))
|
218 |
+
|
219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
220 |
+
x_list,
|
221 |
+
residual_func=attn_residual_func,
|
222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
224 |
+
)
|
225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
+
x_list,
|
227 |
+
residual_func=ffn_residual_func,
|
228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
230 |
+
)
|
231 |
+
return x_list
|
232 |
+
else:
|
233 |
+
|
234 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
235 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
236 |
+
|
237 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
238 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
239 |
+
|
240 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
241 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
242 |
+
x = x + ffn_residual_func(x)
|
243 |
+
return attn_bias.split(x)
|
244 |
+
|
245 |
+
def forward(self, x_or_x_list):
|
246 |
+
if isinstance(x_or_x_list, Tensor):
|
247 |
+
return super().forward(x_or_x_list)
|
248 |
+
elif isinstance(x_or_x_list, list):
|
249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
250 |
+
return self.forward_nested(x_or_x_list)
|
251 |
+
else:
|
252 |
+
raise AssertionError
|
extern/video_depth_anything/dinov2_layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
10 |
+
|
11 |
+
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
16 |
+
if drop_prob == 0.0 or not training:
|
17 |
+
return x
|
18 |
+
keep_prob = 1 - drop_prob
|
19 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
20 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
21 |
+
if keep_prob > 0.0:
|
22 |
+
random_tensor.div_(keep_prob)
|
23 |
+
output = x * random_tensor
|
24 |
+
return output
|
25 |
+
|
26 |
+
|
27 |
+
class DropPath(nn.Module):
|
28 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
29 |
+
|
30 |
+
def __init__(self, drop_prob=None):
|
31 |
+
super(DropPath, self).__init__()
|
32 |
+
self.drop_prob = drop_prob
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return drop_path(x, self.drop_prob, self.training)
|
extern/video_depth_anything/dinov2_layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
8 |
+
|
9 |
+
from typing import Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import Tensor
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
|
16 |
+
class LayerScale(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim: int,
|
20 |
+
init_values: Union[float, Tensor] = 1e-5,
|
21 |
+
inplace: bool = False,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
self.inplace = inplace
|
25 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
26 |
+
|
27 |
+
def forward(self, x: Tensor) -> Tensor:
|
28 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
extern/video_depth_anything/dinov2_layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
10 |
+
|
11 |
+
|
12 |
+
from typing import Callable, Optional
|
13 |
+
|
14 |
+
from torch import Tensor, nn
|
15 |
+
|
16 |
+
|
17 |
+
class Mlp(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_features: int,
|
21 |
+
hidden_features: Optional[int] = None,
|
22 |
+
out_features: Optional[int] = None,
|
23 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
24 |
+
drop: float = 0.0,
|
25 |
+
bias: bool = True,
|
26 |
+
) -> None:
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x: Tensor) -> Tensor:
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
extern/video_depth_anything/dinov2_layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
from typing import Callable, Optional, Tuple, Union
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
|
17 |
+
def make_2tuple(x):
|
18 |
+
if isinstance(x, tuple):
|
19 |
+
assert len(x) == 2
|
20 |
+
return x
|
21 |
+
|
22 |
+
assert isinstance(x, int)
|
23 |
+
return (x, x)
|
24 |
+
|
25 |
+
|
26 |
+
class PatchEmbed(nn.Module):
|
27 |
+
"""
|
28 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
29 |
+
|
30 |
+
Args:
|
31 |
+
img_size: Image size.
|
32 |
+
patch_size: Patch token size.
|
33 |
+
in_chans: Number of input image channels.
|
34 |
+
embed_dim: Number of linear projection output channels.
|
35 |
+
norm_layer: Normalization layer.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
41 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
42 |
+
in_chans: int = 3,
|
43 |
+
embed_dim: int = 768,
|
44 |
+
norm_layer: Optional[Callable] = None,
|
45 |
+
flatten_embedding: bool = True,
|
46 |
+
) -> None:
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
image_HW = make_2tuple(img_size)
|
50 |
+
patch_HW = make_2tuple(patch_size)
|
51 |
+
patch_grid_size = (
|
52 |
+
image_HW[0] // patch_HW[0],
|
53 |
+
image_HW[1] // patch_HW[1],
|
54 |
+
)
|
55 |
+
|
56 |
+
self.img_size = image_HW
|
57 |
+
self.patch_size = patch_HW
|
58 |
+
self.patches_resolution = patch_grid_size
|
59 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
60 |
+
|
61 |
+
self.in_chans = in_chans
|
62 |
+
self.embed_dim = embed_dim
|
63 |
+
|
64 |
+
self.flatten_embedding = flatten_embedding
|
65 |
+
|
66 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
67 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
68 |
+
|
69 |
+
def forward(self, x: Tensor) -> Tensor:
|
70 |
+
_, _, H, W = x.shape
|
71 |
+
patch_H, patch_W = self.patch_size
|
72 |
+
|
73 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
74 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
75 |
+
|
76 |
+
x = self.proj(x) # B C H W
|
77 |
+
H, W = x.size(2), x.size(3)
|
78 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
79 |
+
x = self.norm(x)
|
80 |
+
if not self.flatten_embedding:
|
81 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
82 |
+
return x
|
83 |
+
|
84 |
+
def flops(self) -> float:
|
85 |
+
Ho, Wo = self.patches_resolution
|
86 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
87 |
+
if self.norm is not None:
|
88 |
+
flops += Ho * Wo * self.embed_dim
|
89 |
+
return flops
|
extern/video_depth_anything/dinov2_layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Callable, Optional
|
8 |
+
|
9 |
+
from torch import Tensor, nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class SwiGLUFFN(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
in_features: int,
|
17 |
+
hidden_features: Optional[int] = None,
|
18 |
+
out_features: Optional[int] = None,
|
19 |
+
act_layer: Callable[..., nn.Module] = None,
|
20 |
+
drop: float = 0.0,
|
21 |
+
bias: bool = True,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
out_features = out_features or in_features
|
25 |
+
hidden_features = hidden_features or in_features
|
26 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
27 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
28 |
+
|
29 |
+
def forward(self, x: Tensor) -> Tensor:
|
30 |
+
x12 = self.w12(x)
|
31 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
32 |
+
hidden = F.silu(x1) * x2
|
33 |
+
return self.w3(hidden)
|
34 |
+
|
35 |
+
|
36 |
+
try:
|
37 |
+
from xformers.ops import SwiGLU
|
38 |
+
|
39 |
+
XFORMERS_AVAILABLE = True
|
40 |
+
except ImportError:
|
41 |
+
SwiGLU = SwiGLUFFN
|
42 |
+
XFORMERS_AVAILABLE = False
|
43 |
+
|
44 |
+
|
45 |
+
class SwiGLUFFNFused(SwiGLU):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
in_features: int,
|
49 |
+
hidden_features: Optional[int] = None,
|
50 |
+
out_features: Optional[int] = None,
|
51 |
+
act_layer: Callable[..., nn.Module] = None,
|
52 |
+
drop: float = 0.0,
|
53 |
+
bias: bool = True,
|
54 |
+
) -> None:
|
55 |
+
out_features = out_features or in_features
|
56 |
+
hidden_features = hidden_features or in_features
|
57 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
58 |
+
super().__init__(
|
59 |
+
in_features=in_features,
|
60 |
+
hidden_features=hidden_features,
|
61 |
+
out_features=out_features,
|
62 |
+
bias=bias,
|
63 |
+
)
|
extern/video_depth_anything/dpt.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
19 |
+
|
20 |
+
|
21 |
+
def _make_fusion_block(features, use_bn, size=None):
|
22 |
+
return FeatureFusionBlock(
|
23 |
+
features,
|
24 |
+
nn.ReLU(False),
|
25 |
+
deconv=False,
|
26 |
+
bn=use_bn,
|
27 |
+
expand=False,
|
28 |
+
align_corners=True,
|
29 |
+
size=size,
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
class ConvBlock(nn.Module):
|
34 |
+
def __init__(self, in_feature, out_feature):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.conv_block = nn.Sequential(
|
38 |
+
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
|
39 |
+
nn.BatchNorm2d(out_feature),
|
40 |
+
nn.ReLU(True)
|
41 |
+
)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
return self.conv_block(x)
|
45 |
+
|
46 |
+
|
47 |
+
class DPTHead(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
in_channels,
|
51 |
+
features=256,
|
52 |
+
use_bn=False,
|
53 |
+
out_channels=[256, 512, 1024, 1024],
|
54 |
+
use_clstoken=False
|
55 |
+
):
|
56 |
+
super(DPTHead, self).__init__()
|
57 |
+
|
58 |
+
self.use_clstoken = use_clstoken
|
59 |
+
|
60 |
+
self.projects = nn.ModuleList([
|
61 |
+
nn.Conv2d(
|
62 |
+
in_channels=in_channels,
|
63 |
+
out_channels=out_channel,
|
64 |
+
kernel_size=1,
|
65 |
+
stride=1,
|
66 |
+
padding=0,
|
67 |
+
) for out_channel in out_channels
|
68 |
+
])
|
69 |
+
|
70 |
+
self.resize_layers = nn.ModuleList([
|
71 |
+
nn.ConvTranspose2d(
|
72 |
+
in_channels=out_channels[0],
|
73 |
+
out_channels=out_channels[0],
|
74 |
+
kernel_size=4,
|
75 |
+
stride=4,
|
76 |
+
padding=0),
|
77 |
+
nn.ConvTranspose2d(
|
78 |
+
in_channels=out_channels[1],
|
79 |
+
out_channels=out_channels[1],
|
80 |
+
kernel_size=2,
|
81 |
+
stride=2,
|
82 |
+
padding=0),
|
83 |
+
nn.Identity(),
|
84 |
+
nn.Conv2d(
|
85 |
+
in_channels=out_channels[3],
|
86 |
+
out_channels=out_channels[3],
|
87 |
+
kernel_size=3,
|
88 |
+
stride=2,
|
89 |
+
padding=1)
|
90 |
+
])
|
91 |
+
|
92 |
+
if use_clstoken:
|
93 |
+
self.readout_projects = nn.ModuleList()
|
94 |
+
for _ in range(len(self.projects)):
|
95 |
+
self.readout_projects.append(
|
96 |
+
nn.Sequential(
|
97 |
+
nn.Linear(2 * in_channels, in_channels),
|
98 |
+
nn.GELU()))
|
99 |
+
|
100 |
+
self.scratch = _make_scratch(
|
101 |
+
out_channels,
|
102 |
+
features,
|
103 |
+
groups=1,
|
104 |
+
expand=False,
|
105 |
+
)
|
106 |
+
|
107 |
+
self.scratch.stem_transpose = None
|
108 |
+
|
109 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
110 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
111 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
112 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
113 |
+
|
114 |
+
head_features_1 = features
|
115 |
+
head_features_2 = 32
|
116 |
+
|
117 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
118 |
+
self.scratch.output_conv2 = nn.Sequential(
|
119 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
120 |
+
nn.ReLU(True),
|
121 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
122 |
+
nn.ReLU(True),
|
123 |
+
nn.Identity(),
|
124 |
+
)
|
125 |
+
|
126 |
+
def forward(self, out_features, patch_h, patch_w):
|
127 |
+
out = []
|
128 |
+
for i, x in enumerate(out_features):
|
129 |
+
if self.use_clstoken:
|
130 |
+
x, cls_token = x[0], x[1]
|
131 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
132 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
133 |
+
else:
|
134 |
+
x = x[0]
|
135 |
+
|
136 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
137 |
+
|
138 |
+
x = self.projects[i](x)
|
139 |
+
x = self.resize_layers[i](x)
|
140 |
+
|
141 |
+
out.append(x)
|
142 |
+
|
143 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
144 |
+
|
145 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
146 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
147 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
148 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
149 |
+
|
150 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
151 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
152 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
153 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
154 |
+
|
155 |
+
out = self.scratch.output_conv1(path_1)
|
156 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
157 |
+
out = self.scratch.output_conv2(out)
|
158 |
+
|
159 |
+
return out
|
160 |
+
|
extern/video_depth_anything/dpt_temporal.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.nn as nn
|
17 |
+
from .dpt import DPTHead
|
18 |
+
from .motion_module.motion_module import TemporalModule
|
19 |
+
from easydict import EasyDict
|
20 |
+
|
21 |
+
|
22 |
+
class DPTHeadTemporal(DPTHead):
|
23 |
+
def __init__(self,
|
24 |
+
in_channels,
|
25 |
+
features=256,
|
26 |
+
use_bn=False,
|
27 |
+
out_channels=[256, 512, 1024, 1024],
|
28 |
+
use_clstoken=False,
|
29 |
+
num_frames=32,
|
30 |
+
pe='ape'
|
31 |
+
):
|
32 |
+
super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)
|
33 |
+
|
34 |
+
assert num_frames > 0
|
35 |
+
motion_module_kwargs = EasyDict(num_attention_heads = 8,
|
36 |
+
num_transformer_block = 1,
|
37 |
+
num_attention_blocks = 2,
|
38 |
+
temporal_max_len = num_frames,
|
39 |
+
zero_initialize = True,
|
40 |
+
pos_embedding_type = pe)
|
41 |
+
|
42 |
+
self.motion_modules = nn.ModuleList([
|
43 |
+
TemporalModule(in_channels=out_channels[2],
|
44 |
+
**motion_module_kwargs),
|
45 |
+
TemporalModule(in_channels=out_channels[3],
|
46 |
+
**motion_module_kwargs),
|
47 |
+
TemporalModule(in_channels=features,
|
48 |
+
**motion_module_kwargs),
|
49 |
+
TemporalModule(in_channels=features,
|
50 |
+
**motion_module_kwargs)
|
51 |
+
])
|
52 |
+
|
53 |
+
def forward(self, out_features, patch_h, patch_w, frame_length):
|
54 |
+
out = []
|
55 |
+
for i, x in enumerate(out_features):
|
56 |
+
if self.use_clstoken:
|
57 |
+
x, cls_token = x[0], x[1]
|
58 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
59 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
60 |
+
else:
|
61 |
+
x = x[0]
|
62 |
+
|
63 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()
|
64 |
+
|
65 |
+
B, T = x.shape[0] // frame_length, frame_length
|
66 |
+
x = self.projects[i](x)
|
67 |
+
x = self.resize_layers[i](x)
|
68 |
+
|
69 |
+
out.append(x)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
72 |
+
|
73 |
+
B, T = layer_1.shape[0] // frame_length, frame_length
|
74 |
+
|
75 |
+
layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
76 |
+
layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
77 |
+
|
78 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
79 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
80 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
81 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
82 |
+
|
83 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
84 |
+
path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
85 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
86 |
+
path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
|
87 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
88 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
89 |
+
|
90 |
+
out = self.scratch.output_conv1(path_1)
|
91 |
+
out = F.interpolate(
|
92 |
+
out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
|
93 |
+
)
|
94 |
+
out = self.scratch.output_conv2(out)
|
95 |
+
|
96 |
+
return out
|
extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (12.1 kB). View file
|
|
extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc
ADDED
Binary file (7.39 kB). View file
|
|
extern/video_depth_anything/motion_module/attention.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional, Tuple
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
try:
|
21 |
+
import xformers
|
22 |
+
import xformers.ops
|
23 |
+
|
24 |
+
XFORMERS_AVAILABLE = True
|
25 |
+
except ImportError:
|
26 |
+
print("xFormers not available")
|
27 |
+
XFORMERS_AVAILABLE = False
|
28 |
+
|
29 |
+
|
30 |
+
class CrossAttention(nn.Module):
|
31 |
+
r"""
|
32 |
+
A cross attention layer.
|
33 |
+
|
34 |
+
Parameters:
|
35 |
+
query_dim (`int`): The number of channels in the query.
|
36 |
+
cross_attention_dim (`int`, *optional*):
|
37 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
38 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
39 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
40 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
41 |
+
bias (`bool`, *optional*, defaults to False):
|
42 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
query_dim: int,
|
48 |
+
cross_attention_dim: Optional[int] = None,
|
49 |
+
heads: int = 8,
|
50 |
+
dim_head: int = 64,
|
51 |
+
dropout: float = 0.0,
|
52 |
+
bias=False,
|
53 |
+
upcast_attention: bool = False,
|
54 |
+
upcast_softmax: bool = False,
|
55 |
+
added_kv_proj_dim: Optional[int] = None,
|
56 |
+
norm_num_groups: Optional[int] = None,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
inner_dim = dim_head * heads
|
60 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
61 |
+
self.upcast_attention = upcast_attention
|
62 |
+
self.upcast_softmax = upcast_softmax
|
63 |
+
self.upcast_efficient_attention = False
|
64 |
+
|
65 |
+
self.scale = dim_head**-0.5
|
66 |
+
|
67 |
+
self.heads = heads
|
68 |
+
# for slice_size > 0 the attention score computation
|
69 |
+
# is split across the batch axis to save memory
|
70 |
+
# You can set slice_size with `set_attention_slice`
|
71 |
+
self.sliceable_head_dim = heads
|
72 |
+
self._slice_size = None
|
73 |
+
self._use_memory_efficient_attention_xformers = False
|
74 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
75 |
+
|
76 |
+
if norm_num_groups is not None:
|
77 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
78 |
+
else:
|
79 |
+
self.group_norm = None
|
80 |
+
|
81 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
82 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
83 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
84 |
+
|
85 |
+
if self.added_kv_proj_dim is not None:
|
86 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
87 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
88 |
+
|
89 |
+
self.to_out = nn.ModuleList([])
|
90 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
91 |
+
self.to_out.append(nn.Dropout(dropout))
|
92 |
+
|
93 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
94 |
+
batch_size, seq_len, dim = tensor.shape
|
95 |
+
head_size = self.heads
|
96 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
|
97 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
|
98 |
+
return tensor
|
99 |
+
|
100 |
+
def reshape_heads_to_4d(self, tensor):
|
101 |
+
batch_size, seq_len, dim = tensor.shape
|
102 |
+
head_size = self.heads
|
103 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
|
104 |
+
return tensor
|
105 |
+
|
106 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
107 |
+
batch_size, seq_len, dim = tensor.shape
|
108 |
+
head_size = self.heads
|
109 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
|
110 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
|
111 |
+
return tensor
|
112 |
+
|
113 |
+
def reshape_4d_to_heads(self, tensor):
|
114 |
+
batch_size, seq_len, head_size, dim = tensor.shape
|
115 |
+
head_size = self.heads
|
116 |
+
tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()
|
117 |
+
return tensor
|
118 |
+
|
119 |
+
def set_attention_slice(self, slice_size):
|
120 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
121 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
122 |
+
|
123 |
+
self._slice_size = slice_size
|
124 |
+
|
125 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
126 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
127 |
+
|
128 |
+
encoder_hidden_states = encoder_hidden_states
|
129 |
+
|
130 |
+
if self.group_norm is not None:
|
131 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
132 |
+
|
133 |
+
query = self.to_q(hidden_states)
|
134 |
+
dim = query.shape[-1]
|
135 |
+
query = self.reshape_heads_to_batch_dim(query)
|
136 |
+
|
137 |
+
if self.added_kv_proj_dim is not None:
|
138 |
+
key = self.to_k(hidden_states)
|
139 |
+
value = self.to_v(hidden_states)
|
140 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
141 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
142 |
+
|
143 |
+
key = self.reshape_heads_to_batch_dim(key)
|
144 |
+
value = self.reshape_heads_to_batch_dim(value)
|
145 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
146 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
147 |
+
|
148 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
149 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
150 |
+
else:
|
151 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
152 |
+
key = self.to_k(encoder_hidden_states)
|
153 |
+
value = self.to_v(encoder_hidden_states)
|
154 |
+
|
155 |
+
key = self.reshape_heads_to_batch_dim(key)
|
156 |
+
value = self.reshape_heads_to_batch_dim(value)
|
157 |
+
|
158 |
+
if attention_mask is not None:
|
159 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
160 |
+
target_length = query.shape[1]
|
161 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
162 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
163 |
+
|
164 |
+
# attention, what we cannot get enough of
|
165 |
+
if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:
|
166 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
167 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
168 |
+
hidden_states = hidden_states.to(query.dtype)
|
169 |
+
else:
|
170 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
171 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
172 |
+
else:
|
173 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
174 |
+
|
175 |
+
# linear proj
|
176 |
+
hidden_states = self.to_out[0](hidden_states)
|
177 |
+
|
178 |
+
# dropout
|
179 |
+
hidden_states = self.to_out[1](hidden_states)
|
180 |
+
return hidden_states
|
181 |
+
|
182 |
+
def _attention(self, query, key, value, attention_mask=None):
|
183 |
+
if self.upcast_attention:
|
184 |
+
query = query.float()
|
185 |
+
key = key.float()
|
186 |
+
|
187 |
+
attention_scores = torch.baddbmm(
|
188 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
189 |
+
query,
|
190 |
+
key.transpose(-1, -2),
|
191 |
+
beta=0,
|
192 |
+
alpha=self.scale,
|
193 |
+
)
|
194 |
+
|
195 |
+
if attention_mask is not None:
|
196 |
+
attention_scores = attention_scores + attention_mask
|
197 |
+
|
198 |
+
if self.upcast_softmax:
|
199 |
+
attention_scores = attention_scores.float()
|
200 |
+
|
201 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
202 |
+
|
203 |
+
# cast back to the original dtype
|
204 |
+
attention_probs = attention_probs.to(value.dtype)
|
205 |
+
|
206 |
+
# compute attention output
|
207 |
+
hidden_states = torch.bmm(attention_probs, value)
|
208 |
+
|
209 |
+
# reshape hidden_states
|
210 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
211 |
+
return hidden_states
|
212 |
+
|
213 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
214 |
+
batch_size_attention = query.shape[0]
|
215 |
+
hidden_states = torch.zeros(
|
216 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
217 |
+
)
|
218 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
219 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
220 |
+
start_idx = i * slice_size
|
221 |
+
end_idx = (i + 1) * slice_size
|
222 |
+
|
223 |
+
query_slice = query[start_idx:end_idx]
|
224 |
+
key_slice = key[start_idx:end_idx]
|
225 |
+
|
226 |
+
if self.upcast_attention:
|
227 |
+
query_slice = query_slice.float()
|
228 |
+
key_slice = key_slice.float()
|
229 |
+
|
230 |
+
attn_slice = torch.baddbmm(
|
231 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
232 |
+
query_slice,
|
233 |
+
key_slice.transpose(-1, -2),
|
234 |
+
beta=0,
|
235 |
+
alpha=self.scale,
|
236 |
+
)
|
237 |
+
|
238 |
+
if attention_mask is not None:
|
239 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
240 |
+
|
241 |
+
if self.upcast_softmax:
|
242 |
+
attn_slice = attn_slice.float()
|
243 |
+
|
244 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
245 |
+
|
246 |
+
# cast back to the original dtype
|
247 |
+
attn_slice = attn_slice.to(value.dtype)
|
248 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
249 |
+
|
250 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
251 |
+
|
252 |
+
# reshape hidden_states
|
253 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
254 |
+
return hidden_states
|
255 |
+
|
256 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
257 |
+
if self.upcast_efficient_attention:
|
258 |
+
org_dtype = query.dtype
|
259 |
+
query = query.float()
|
260 |
+
key = key.float()
|
261 |
+
value = value.float()
|
262 |
+
if attention_mask is not None:
|
263 |
+
attention_mask = attention_mask.float()
|
264 |
+
hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)
|
265 |
+
|
266 |
+
if self.upcast_efficient_attention:
|
267 |
+
hidden_states = hidden_states.to(org_dtype)
|
268 |
+
|
269 |
+
hidden_states = self.reshape_4d_to_heads(hidden_states)
|
270 |
+
return hidden_states
|
271 |
+
|
272 |
+
# print("Errror: no xformers")
|
273 |
+
# raise NotImplementedError
|
274 |
+
|
275 |
+
def _memory_efficient_attention_split(self, query, key, value, attention_mask):
|
276 |
+
batch_size = query.shape[0]
|
277 |
+
max_batch_size = 65535
|
278 |
+
num_batches = (batch_size + max_batch_size - 1) // max_batch_size
|
279 |
+
results = []
|
280 |
+
for i in range(num_batches):
|
281 |
+
start_idx = i * max_batch_size
|
282 |
+
end_idx = min((i + 1) * max_batch_size, batch_size)
|
283 |
+
query_batch = query[start_idx:end_idx]
|
284 |
+
key_batch = key[start_idx:end_idx]
|
285 |
+
value_batch = value[start_idx:end_idx]
|
286 |
+
if attention_mask is not None:
|
287 |
+
attention_mask_batch = attention_mask[start_idx:end_idx]
|
288 |
+
else:
|
289 |
+
attention_mask_batch = None
|
290 |
+
result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
|
291 |
+
results.append(result)
|
292 |
+
full_result = torch.cat(results, dim=0)
|
293 |
+
return full_result
|
294 |
+
|
295 |
+
|
296 |
+
class FeedForward(nn.Module):
|
297 |
+
r"""
|
298 |
+
A feed-forward layer.
|
299 |
+
|
300 |
+
Parameters:
|
301 |
+
dim (`int`): The number of channels in the input.
|
302 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
303 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
304 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
305 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
306 |
+
"""
|
307 |
+
|
308 |
+
def __init__(
|
309 |
+
self,
|
310 |
+
dim: int,
|
311 |
+
dim_out: Optional[int] = None,
|
312 |
+
mult: int = 4,
|
313 |
+
dropout: float = 0.0,
|
314 |
+
activation_fn: str = "geglu",
|
315 |
+
):
|
316 |
+
super().__init__()
|
317 |
+
inner_dim = int(dim * mult)
|
318 |
+
dim_out = dim_out if dim_out is not None else dim
|
319 |
+
|
320 |
+
if activation_fn == "gelu":
|
321 |
+
act_fn = GELU(dim, inner_dim)
|
322 |
+
elif activation_fn == "geglu":
|
323 |
+
act_fn = GEGLU(dim, inner_dim)
|
324 |
+
elif activation_fn == "geglu-approximate":
|
325 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
326 |
+
|
327 |
+
self.net = nn.ModuleList([])
|
328 |
+
# project in
|
329 |
+
self.net.append(act_fn)
|
330 |
+
# project dropout
|
331 |
+
self.net.append(nn.Dropout(dropout))
|
332 |
+
# project out
|
333 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
334 |
+
|
335 |
+
def forward(self, hidden_states):
|
336 |
+
for module in self.net:
|
337 |
+
hidden_states = module(hidden_states)
|
338 |
+
return hidden_states
|
339 |
+
|
340 |
+
|
341 |
+
class GELU(nn.Module):
|
342 |
+
r"""
|
343 |
+
GELU activation function
|
344 |
+
"""
|
345 |
+
|
346 |
+
def __init__(self, dim_in: int, dim_out: int):
|
347 |
+
super().__init__()
|
348 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
349 |
+
|
350 |
+
def gelu(self, gate):
|
351 |
+
if gate.device.type != "mps":
|
352 |
+
return F.gelu(gate)
|
353 |
+
# mps: gelu is not implemented for float16
|
354 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
355 |
+
|
356 |
+
def forward(self, hidden_states):
|
357 |
+
hidden_states = self.proj(hidden_states)
|
358 |
+
hidden_states = self.gelu(hidden_states)
|
359 |
+
return hidden_states
|
360 |
+
|
361 |
+
|
362 |
+
# feedforward
|
363 |
+
class GEGLU(nn.Module):
|
364 |
+
r"""
|
365 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
366 |
+
|
367 |
+
Parameters:
|
368 |
+
dim_in (`int`): The number of channels in the input.
|
369 |
+
dim_out (`int`): The number of channels in the output.
|
370 |
+
"""
|
371 |
+
|
372 |
+
def __init__(self, dim_in: int, dim_out: int):
|
373 |
+
super().__init__()
|
374 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
375 |
+
|
376 |
+
def gelu(self, gate):
|
377 |
+
if gate.device.type != "mps":
|
378 |
+
return F.gelu(gate)
|
379 |
+
# mps: gelu is not implemented for float16
|
380 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
381 |
+
|
382 |
+
def forward(self, hidden_states):
|
383 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
384 |
+
return hidden_states * self.gelu(gate)
|
385 |
+
|
386 |
+
|
387 |
+
class ApproximateGELU(nn.Module):
|
388 |
+
"""
|
389 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
390 |
+
|
391 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
392 |
+
"""
|
393 |
+
|
394 |
+
def __init__(self, dim_in: int, dim_out: int):
|
395 |
+
super().__init__()
|
396 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
397 |
+
|
398 |
+
def forward(self, x):
|
399 |
+
x = self.proj(x)
|
400 |
+
return x * torch.sigmoid(1.702 * x)
|
401 |
+
|
402 |
+
|
403 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
404 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
405 |
+
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
406 |
+
freqs = torch.outer(t, freqs)
|
407 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
408 |
+
return freqs_cis
|
409 |
+
|
410 |
+
|
411 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
412 |
+
ndim = x.ndim
|
413 |
+
assert 0 <= 1 < ndim
|
414 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
415 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
416 |
+
return freqs_cis.view(*shape)
|
417 |
+
|
418 |
+
|
419 |
+
def apply_rotary_emb(
|
420 |
+
xq: torch.Tensor,
|
421 |
+
xk: torch.Tensor,
|
422 |
+
freqs_cis: torch.Tensor,
|
423 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
424 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())
|
425 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())
|
426 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
427 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
|
428 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
|
429 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
extern/video_depth_anything/motion_module/motion_module.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
|
2 |
+
# SPDX-License-Identifier: Apache-2.0 license
|
3 |
+
#
|
4 |
+
# This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
|
5 |
+
# Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
|
11 |
+
|
12 |
+
from einops import rearrange, repeat
|
13 |
+
import math
|
14 |
+
|
15 |
+
try:
|
16 |
+
import xformers
|
17 |
+
import xformers.ops
|
18 |
+
|
19 |
+
XFORMERS_AVAILABLE = True
|
20 |
+
except ImportError:
|
21 |
+
print("xFormers not available")
|
22 |
+
XFORMERS_AVAILABLE = False
|
23 |
+
|
24 |
+
|
25 |
+
def zero_module(module):
|
26 |
+
# Zero out the parameters of a module and return it.
|
27 |
+
for p in module.parameters():
|
28 |
+
p.detach().zero_()
|
29 |
+
return module
|
30 |
+
|
31 |
+
|
32 |
+
class TemporalModule(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
in_channels,
|
36 |
+
num_attention_heads = 8,
|
37 |
+
num_transformer_block = 2,
|
38 |
+
num_attention_blocks = 2,
|
39 |
+
norm_num_groups = 32,
|
40 |
+
temporal_max_len = 32,
|
41 |
+
zero_initialize = True,
|
42 |
+
pos_embedding_type = "ape",
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
47 |
+
in_channels=in_channels,
|
48 |
+
num_attention_heads=num_attention_heads,
|
49 |
+
attention_head_dim=in_channels // num_attention_heads,
|
50 |
+
num_layers=num_transformer_block,
|
51 |
+
num_attention_blocks=num_attention_blocks,
|
52 |
+
norm_num_groups=norm_num_groups,
|
53 |
+
temporal_max_len=temporal_max_len,
|
54 |
+
pos_embedding_type=pos_embedding_type,
|
55 |
+
)
|
56 |
+
|
57 |
+
if zero_initialize:
|
58 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
59 |
+
|
60 |
+
def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
|
61 |
+
hidden_states = input_tensor
|
62 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
63 |
+
|
64 |
+
output = hidden_states
|
65 |
+
return output
|
66 |
+
|
67 |
+
|
68 |
+
class TemporalTransformer3DModel(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
in_channels,
|
72 |
+
num_attention_heads,
|
73 |
+
attention_head_dim,
|
74 |
+
num_layers,
|
75 |
+
num_attention_blocks = 2,
|
76 |
+
norm_num_groups = 32,
|
77 |
+
temporal_max_len = 32,
|
78 |
+
pos_embedding_type = "ape",
|
79 |
+
):
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
inner_dim = num_attention_heads * attention_head_dim
|
83 |
+
|
84 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
85 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
86 |
+
|
87 |
+
self.transformer_blocks = nn.ModuleList(
|
88 |
+
[
|
89 |
+
TemporalTransformerBlock(
|
90 |
+
dim=inner_dim,
|
91 |
+
num_attention_heads=num_attention_heads,
|
92 |
+
attention_head_dim=attention_head_dim,
|
93 |
+
num_attention_blocks=num_attention_blocks,
|
94 |
+
temporal_max_len=temporal_max_len,
|
95 |
+
pos_embedding_type=pos_embedding_type,
|
96 |
+
)
|
97 |
+
for d in range(num_layers)
|
98 |
+
]
|
99 |
+
)
|
100 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
101 |
+
|
102 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
103 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
104 |
+
video_length = hidden_states.shape[2]
|
105 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
106 |
+
|
107 |
+
batch, channel, height, width = hidden_states.shape
|
108 |
+
residual = hidden_states
|
109 |
+
|
110 |
+
hidden_states = self.norm(hidden_states)
|
111 |
+
inner_dim = hidden_states.shape[1]
|
112 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()
|
113 |
+
hidden_states = self.proj_in(hidden_states)
|
114 |
+
|
115 |
+
# Transformer Blocks
|
116 |
+
for block in self.transformer_blocks:
|
117 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
|
118 |
+
|
119 |
+
# output
|
120 |
+
hidden_states = self.proj_out(hidden_states)
|
121 |
+
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
122 |
+
|
123 |
+
output = hidden_states + residual
|
124 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
125 |
+
|
126 |
+
return output
|
127 |
+
|
128 |
+
|
129 |
+
class TemporalTransformerBlock(nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
dim,
|
133 |
+
num_attention_heads,
|
134 |
+
attention_head_dim,
|
135 |
+
num_attention_blocks = 2,
|
136 |
+
temporal_max_len = 32,
|
137 |
+
pos_embedding_type = "ape",
|
138 |
+
):
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
self.attention_blocks = nn.ModuleList(
|
142 |
+
[
|
143 |
+
TemporalAttention(
|
144 |
+
query_dim=dim,
|
145 |
+
heads=num_attention_heads,
|
146 |
+
dim_head=attention_head_dim,
|
147 |
+
temporal_max_len=temporal_max_len,
|
148 |
+
pos_embedding_type=pos_embedding_type,
|
149 |
+
)
|
150 |
+
for i in range(num_attention_blocks)
|
151 |
+
]
|
152 |
+
)
|
153 |
+
self.norms = nn.ModuleList(
|
154 |
+
[
|
155 |
+
nn.LayerNorm(dim)
|
156 |
+
for i in range(num_attention_blocks)
|
157 |
+
]
|
158 |
+
)
|
159 |
+
|
160 |
+
self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
|
161 |
+
self.ff_norm = nn.LayerNorm(dim)
|
162 |
+
|
163 |
+
|
164 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
165 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
166 |
+
norm_hidden_states = norm(hidden_states)
|
167 |
+
hidden_states = attention_block(
|
168 |
+
norm_hidden_states,
|
169 |
+
encoder_hidden_states=encoder_hidden_states,
|
170 |
+
video_length=video_length,
|
171 |
+
attention_mask=attention_mask,
|
172 |
+
) + hidden_states
|
173 |
+
|
174 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
175 |
+
|
176 |
+
output = hidden_states
|
177 |
+
return output
|
178 |
+
|
179 |
+
|
180 |
+
class PositionalEncoding(nn.Module):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
d_model,
|
184 |
+
dropout = 0.,
|
185 |
+
max_len = 32
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
self.dropout = nn.Dropout(p=dropout)
|
189 |
+
position = torch.arange(max_len).unsqueeze(1)
|
190 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
191 |
+
pe = torch.zeros(1, max_len, d_model)
|
192 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
193 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
194 |
+
self.register_buffer('pe', pe)
|
195 |
+
|
196 |
+
def forward(self, x):
|
197 |
+
x = x + self.pe[:, :x.size(1)].to(x.dtype)
|
198 |
+
return self.dropout(x)
|
199 |
+
|
200 |
+
class TemporalAttention(CrossAttention):
|
201 |
+
def __init__(
|
202 |
+
self,
|
203 |
+
temporal_max_len = 32,
|
204 |
+
pos_embedding_type = "ape",
|
205 |
+
*args, **kwargs
|
206 |
+
):
|
207 |
+
super().__init__(*args, **kwargs)
|
208 |
+
|
209 |
+
self.pos_embedding_type = pos_embedding_type
|
210 |
+
self._use_memory_efficient_attention_xformers = True
|
211 |
+
|
212 |
+
self.pos_encoder = None
|
213 |
+
self.freqs_cis = None
|
214 |
+
if self.pos_embedding_type == "ape":
|
215 |
+
self.pos_encoder = PositionalEncoding(
|
216 |
+
kwargs["query_dim"],
|
217 |
+
dropout=0.,
|
218 |
+
max_len=temporal_max_len
|
219 |
+
)
|
220 |
+
|
221 |
+
elif self.pos_embedding_type == "rope":
|
222 |
+
self.freqs_cis = precompute_freqs_cis(
|
223 |
+
kwargs["query_dim"],
|
224 |
+
temporal_max_len
|
225 |
+
)
|
226 |
+
|
227 |
+
else:
|
228 |
+
raise NotImplementedError
|
229 |
+
|
230 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
231 |
+
d = hidden_states.shape[1]
|
232 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
233 |
+
|
234 |
+
if self.pos_encoder is not None:
|
235 |
+
hidden_states = self.pos_encoder(hidden_states)
|
236 |
+
|
237 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
238 |
+
|
239 |
+
if self.group_norm is not None:
|
240 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
241 |
+
|
242 |
+
query = self.to_q(hidden_states)
|
243 |
+
dim = query.shape[-1]
|
244 |
+
|
245 |
+
if self.added_kv_proj_dim is not None:
|
246 |
+
raise NotImplementedError
|
247 |
+
|
248 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
249 |
+
key = self.to_k(encoder_hidden_states)
|
250 |
+
value = self.to_v(encoder_hidden_states)
|
251 |
+
|
252 |
+
if self.freqs_cis is not None:
|
253 |
+
seq_len = query.shape[1]
|
254 |
+
freqs_cis = self.freqs_cis[:seq_len].to(query.device)
|
255 |
+
query, key = apply_rotary_emb(query, key, freqs_cis)
|
256 |
+
|
257 |
+
if attention_mask is not None:
|
258 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
259 |
+
target_length = query.shape[1]
|
260 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
261 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
262 |
+
|
263 |
+
|
264 |
+
use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers
|
265 |
+
if use_memory_efficient and (dim // self.heads) % 8 != 0:
|
266 |
+
# print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
|
267 |
+
use_memory_efficient = False
|
268 |
+
|
269 |
+
# attention, what we cannot get enough of
|
270 |
+
if use_memory_efficient:
|
271 |
+
query = self.reshape_heads_to_4d(query)
|
272 |
+
key = self.reshape_heads_to_4d(key)
|
273 |
+
value = self.reshape_heads_to_4d(value)
|
274 |
+
|
275 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
276 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
277 |
+
hidden_states = hidden_states.to(query.dtype)
|
278 |
+
else:
|
279 |
+
query = self.reshape_heads_to_batch_dim(query)
|
280 |
+
key = self.reshape_heads_to_batch_dim(key)
|
281 |
+
value = self.reshape_heads_to_batch_dim(value)
|
282 |
+
|
283 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
284 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
285 |
+
else:
|
286 |
+
raise NotImplementedError
|
287 |
+
# hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
288 |
+
|
289 |
+
# linear proj
|
290 |
+
hidden_states = self.to_out[0](hidden_states)
|
291 |
+
|
292 |
+
# dropout
|
293 |
+
hidden_states = self.to_out[1](hidden_states)
|
294 |
+
|
295 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
296 |
+
|
297 |
+
return hidden_states
|
extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc
ADDED
Binary file (3.24 kB). View file
|
|
extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc
ADDED
Binary file (4.72 kB). View file
|
|
extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc
ADDED
Binary file (1.66 kB). View file
|
|
extern/video_depth_anything/util/blocks.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
5 |
+
scratch = nn.Module()
|
6 |
+
|
7 |
+
out_shape1 = out_shape
|
8 |
+
out_shape2 = out_shape
|
9 |
+
out_shape3 = out_shape
|
10 |
+
if len(in_shape) >= 4:
|
11 |
+
out_shape4 = out_shape
|
12 |
+
|
13 |
+
if expand:
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape * 2
|
16 |
+
out_shape3 = out_shape * 4
|
17 |
+
if len(in_shape) >= 4:
|
18 |
+
out_shape4 = out_shape * 8
|
19 |
+
|
20 |
+
scratch.layer1_rn = nn.Conv2d(
|
21 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
22 |
+
)
|
23 |
+
scratch.layer2_rn = nn.Conv2d(
|
24 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
25 |
+
)
|
26 |
+
scratch.layer3_rn = nn.Conv2d(
|
27 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
28 |
+
)
|
29 |
+
if len(in_shape) >= 4:
|
30 |
+
scratch.layer4_rn = nn.Conv2d(
|
31 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
32 |
+
)
|
33 |
+
|
34 |
+
return scratch
|
35 |
+
|
36 |
+
|
37 |
+
class ResidualConvUnit(nn.Module):
|
38 |
+
"""Residual convolution module."""
|
39 |
+
|
40 |
+
def __init__(self, features, activation, bn):
|
41 |
+
"""Init.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
features (int): number of features
|
45 |
+
"""
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.bn = bn
|
49 |
+
|
50 |
+
self.groups = 1
|
51 |
+
|
52 |
+
self.conv1 = nn.Conv2d(
|
53 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
54 |
+
)
|
55 |
+
|
56 |
+
self.conv2 = nn.Conv2d(
|
57 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
58 |
+
)
|
59 |
+
|
60 |
+
if self.bn is True:
|
61 |
+
self.bn1 = nn.BatchNorm2d(features)
|
62 |
+
self.bn2 = nn.BatchNorm2d(features)
|
63 |
+
|
64 |
+
self.activation = activation
|
65 |
+
|
66 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
"""Forward pass.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
x (tensor): input
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
tensor: output
|
76 |
+
"""
|
77 |
+
|
78 |
+
out = self.activation(x)
|
79 |
+
out = self.conv1(out)
|
80 |
+
if self.bn is True:
|
81 |
+
out = self.bn1(out)
|
82 |
+
|
83 |
+
out = self.activation(out)
|
84 |
+
out = self.conv2(out)
|
85 |
+
if self.bn is True:
|
86 |
+
out = self.bn2(out)
|
87 |
+
|
88 |
+
if self.groups > 1:
|
89 |
+
out = self.conv_merge(out)
|
90 |
+
|
91 |
+
return self.skip_add.add(out, x)
|
92 |
+
|
93 |
+
|
94 |
+
class FeatureFusionBlock(nn.Module):
|
95 |
+
"""Feature fusion block."""
|
96 |
+
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
features,
|
100 |
+
activation,
|
101 |
+
deconv=False,
|
102 |
+
bn=False,
|
103 |
+
expand=False,
|
104 |
+
align_corners=True,
|
105 |
+
size=None,
|
106 |
+
):
|
107 |
+
"""Init.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
features (int): number of features
|
111 |
+
"""
|
112 |
+
super().__init__()
|
113 |
+
|
114 |
+
self.deconv = deconv
|
115 |
+
self.align_corners = align_corners
|
116 |
+
|
117 |
+
self.groups = 1
|
118 |
+
|
119 |
+
self.expand = expand
|
120 |
+
out_features = features
|
121 |
+
if self.expand is True:
|
122 |
+
out_features = features // 2
|
123 |
+
|
124 |
+
self.out_conv = nn.Conv2d(
|
125 |
+
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
|
126 |
+
)
|
127 |
+
|
128 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
129 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
130 |
+
|
131 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
132 |
+
|
133 |
+
self.size = size
|
134 |
+
|
135 |
+
def forward(self, *xs, size=None):
|
136 |
+
"""Forward pass.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
tensor: output
|
140 |
+
"""
|
141 |
+
output = xs[0]
|
142 |
+
|
143 |
+
if len(xs) == 2:
|
144 |
+
res = self.resConfUnit1(xs[1])
|
145 |
+
output = self.skip_add.add(output, res)
|
146 |
+
|
147 |
+
output = self.resConfUnit2(output)
|
148 |
+
|
149 |
+
if (size is None) and (self.size is None):
|
150 |
+
modifier = {"scale_factor": 2}
|
151 |
+
elif size is None:
|
152 |
+
modifier = {"size": self.size}
|
153 |
+
else:
|
154 |
+
modifier = {"size": size}
|
155 |
+
|
156 |
+
output = nn.functional.interpolate(
|
157 |
+
output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners
|
158 |
+
)
|
159 |
+
|
160 |
+
output = self.out_conv(output)
|
161 |
+
|
162 |
+
return output
|
extern/video_depth_anything/util/transform.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
|
5 |
+
class Resize(object):
|
6 |
+
"""Resize sample to given size (width, height).
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
width,
|
12 |
+
height,
|
13 |
+
resize_target=True,
|
14 |
+
keep_aspect_ratio=False,
|
15 |
+
ensure_multiple_of=1,
|
16 |
+
resize_method="lower_bound",
|
17 |
+
image_interpolation_method=cv2.INTER_AREA,
|
18 |
+
):
|
19 |
+
"""Init.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
width (int): desired output width
|
23 |
+
height (int): desired output height
|
24 |
+
resize_target (bool, optional):
|
25 |
+
True: Resize the full sample (image, mask, target).
|
26 |
+
False: Resize image only.
|
27 |
+
Defaults to True.
|
28 |
+
keep_aspect_ratio (bool, optional):
|
29 |
+
True: Keep the aspect ratio of the input sample.
|
30 |
+
Output sample might not have the given width and height, and
|
31 |
+
resize behaviour depends on the parameter 'resize_method'.
|
32 |
+
Defaults to False.
|
33 |
+
ensure_multiple_of (int, optional):
|
34 |
+
Output width and height is constrained to be multiple of this parameter.
|
35 |
+
Defaults to 1.
|
36 |
+
resize_method (str, optional):
|
37 |
+
"lower_bound": Output will be at least as large as the given size.
|
38 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
39 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
40 |
+
Defaults to "lower_bound".
|
41 |
+
"""
|
42 |
+
self.__width = width
|
43 |
+
self.__height = height
|
44 |
+
|
45 |
+
self.__resize_target = resize_target
|
46 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
47 |
+
self.__multiple_of = ensure_multiple_of
|
48 |
+
self.__resize_method = resize_method
|
49 |
+
self.__image_interpolation_method = image_interpolation_method
|
50 |
+
|
51 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
52 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
53 |
+
|
54 |
+
if max_val is not None and y > max_val:
|
55 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
56 |
+
|
57 |
+
if y < min_val:
|
58 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
59 |
+
|
60 |
+
return y
|
61 |
+
|
62 |
+
def get_size(self, width, height):
|
63 |
+
# determine new height and width
|
64 |
+
scale_height = self.__height / height
|
65 |
+
scale_width = self.__width / width
|
66 |
+
|
67 |
+
if self.__keep_aspect_ratio:
|
68 |
+
if self.__resize_method == "lower_bound":
|
69 |
+
# scale such that output size is lower bound
|
70 |
+
if scale_width > scale_height:
|
71 |
+
# fit width
|
72 |
+
scale_height = scale_width
|
73 |
+
else:
|
74 |
+
# fit height
|
75 |
+
scale_width = scale_height
|
76 |
+
elif self.__resize_method == "upper_bound":
|
77 |
+
# scale such that output size is upper bound
|
78 |
+
if scale_width < scale_height:
|
79 |
+
# fit width
|
80 |
+
scale_height = scale_width
|
81 |
+
else:
|
82 |
+
# fit height
|
83 |
+
scale_width = scale_height
|
84 |
+
elif self.__resize_method == "minimal":
|
85 |
+
# scale as least as possbile
|
86 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
87 |
+
# fit width
|
88 |
+
scale_height = scale_width
|
89 |
+
else:
|
90 |
+
# fit height
|
91 |
+
scale_width = scale_height
|
92 |
+
else:
|
93 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
94 |
+
|
95 |
+
if self.__resize_method == "lower_bound":
|
96 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
97 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
98 |
+
elif self.__resize_method == "upper_bound":
|
99 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
100 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
101 |
+
elif self.__resize_method == "minimal":
|
102 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
103 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
104 |
+
else:
|
105 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
106 |
+
|
107 |
+
return (new_width, new_height)
|
108 |
+
|
109 |
+
def __call__(self, sample):
|
110 |
+
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
111 |
+
|
112 |
+
# resize sample
|
113 |
+
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
|
114 |
+
|
115 |
+
if self.__resize_target:
|
116 |
+
if "depth" in sample:
|
117 |
+
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
118 |
+
|
119 |
+
if "mask" in sample:
|
120 |
+
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
|
121 |
+
|
122 |
+
return sample
|
123 |
+
|
124 |
+
|
125 |
+
class NormalizeImage(object):
|
126 |
+
"""Normlize image by given mean and std.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, mean, std):
|
130 |
+
self.__mean = mean
|
131 |
+
self.__std = std
|
132 |
+
|
133 |
+
def __call__(self, sample):
|
134 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
135 |
+
|
136 |
+
return sample
|
137 |
+
|
138 |
+
|
139 |
+
class PrepareForNet(object):
|
140 |
+
"""Prepare sample for usage as network input.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self):
|
144 |
+
pass
|
145 |
+
|
146 |
+
def __call__(self, sample):
|
147 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
148 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
149 |
+
|
150 |
+
if "depth" in sample:
|
151 |
+
depth = sample["depth"].astype(np.float32)
|
152 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
153 |
+
|
154 |
+
if "mask" in sample:
|
155 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
156 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
157 |
+
|
158 |
+
return sample
|
extern/video_depth_anything/util/util.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
def compute_scale_and_shift(prediction, target, mask, scale_only=False):
|
17 |
+
if scale_only:
|
18 |
+
return compute_scale(prediction, target, mask), 0
|
19 |
+
else:
|
20 |
+
return compute_scale_and_shift_full(prediction, target, mask)
|
21 |
+
|
22 |
+
|
23 |
+
def compute_scale(prediction, target, mask):
|
24 |
+
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
|
25 |
+
prediction = prediction.astype(np.float32)
|
26 |
+
target = target.astype(np.float32)
|
27 |
+
mask = mask.astype(np.float32)
|
28 |
+
|
29 |
+
a_00 = np.sum(mask * prediction * prediction)
|
30 |
+
a_01 = np.sum(mask * prediction)
|
31 |
+
a_11 = np.sum(mask)
|
32 |
+
|
33 |
+
# right hand side: b = [b_0, b_1]
|
34 |
+
b_0 = np.sum(mask * prediction * target)
|
35 |
+
|
36 |
+
x_0 = b_0 / (a_00 + 1e-6)
|
37 |
+
|
38 |
+
return x_0
|
39 |
+
|
40 |
+
def compute_scale_and_shift_full(prediction, target, mask):
|
41 |
+
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
|
42 |
+
prediction = prediction.astype(np.float32)
|
43 |
+
target = target.astype(np.float32)
|
44 |
+
mask = mask.astype(np.float32)
|
45 |
+
|
46 |
+
a_00 = np.sum(mask * prediction * prediction)
|
47 |
+
a_01 = np.sum(mask * prediction)
|
48 |
+
a_11 = np.sum(mask)
|
49 |
+
|
50 |
+
b_0 = np.sum(mask * prediction * target)
|
51 |
+
b_1 = np.sum(mask * target)
|
52 |
+
|
53 |
+
x_0 = 1
|
54 |
+
x_1 = 0
|
55 |
+
|
56 |
+
det = a_00 * a_11 - a_01 * a_01
|
57 |
+
|
58 |
+
if det != 0:
|
59 |
+
x_0 = (a_11 * b_0 - a_01 * b_1) / det
|
60 |
+
x_1 = (-a_01 * b_0 + a_00 * b_1) / det
|
61 |
+
|
62 |
+
return x_0, x_1
|
63 |
+
|
64 |
+
|
65 |
+
def get_interpolate_frames(frame_list_pre, frame_list_post):
|
66 |
+
assert len(frame_list_pre) == len(frame_list_post)
|
67 |
+
min_w = 0.0
|
68 |
+
max_w = 1.0
|
69 |
+
step = (max_w - min_w) / (len(frame_list_pre)-1)
|
70 |
+
post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
|
71 |
+
interpolated_frames = []
|
72 |
+
for i in range(len(frame_list_pre)):
|
73 |
+
interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
|
74 |
+
return interpolated_frames
|
extern/video_depth_anything/vdademo.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import argparse
|
15 |
+
import numpy as np
|
16 |
+
import os
|
17 |
+
import torch
|
18 |
+
from extern.video_depth_anything.video_depth import VideoDepthAnything
|
19 |
+
|
20 |
+
class VDADemo:
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
pre_train_path: str,
|
24 |
+
encoder: str = "vitl",
|
25 |
+
device: str = "cuda:0",
|
26 |
+
):
|
27 |
+
|
28 |
+
model_configs = {
|
29 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
30 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
31 |
+
}
|
32 |
+
|
33 |
+
self.video_depth_anything = VideoDepthAnything(**model_configs[encoder])
|
34 |
+
self.video_depth_anything.load_state_dict(torch.load(pre_train_path, map_location='cpu'), strict=True)
|
35 |
+
self.video_depth_anything = self.video_depth_anything.to(device).eval()
|
36 |
+
self.device = device
|
37 |
+
|
38 |
+
def infer(
|
39 |
+
self,
|
40 |
+
frames,
|
41 |
+
near,
|
42 |
+
far,
|
43 |
+
input_size = 518,
|
44 |
+
target_fps = -1,
|
45 |
+
):
|
46 |
+
if frames.max() < 2.:
|
47 |
+
frames = frames*255.
|
48 |
+
|
49 |
+
with torch.inference_mode():
|
50 |
+
depths, fps = self.video_depth_anything.infer_video_depth(frames, target_fps, input_size, self.device)
|
51 |
+
|
52 |
+
depths = torch.from_numpy(depths).unsqueeze(1) # 49 576 1024 ->
|
53 |
+
depths[depths < 1e-5] = 1e-5
|
54 |
+
depths = 10000. / depths
|
55 |
+
depths = depths.clip(near, far)
|
56 |
+
|
57 |
+
|
58 |
+
return depths
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
extern/video_depth_anything/video_depth.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2025) Bytedance Ltd. and/or its affiliates
|
2 |
+
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import torch.nn as nn
|
17 |
+
from torchvision.transforms import Compose
|
18 |
+
import cv2
|
19 |
+
from tqdm import tqdm
|
20 |
+
import numpy as np
|
21 |
+
import gc
|
22 |
+
|
23 |
+
from extern.video_depth_anything.dinov2 import DINOv2
|
24 |
+
from extern.video_depth_anything.dpt_temporal import DPTHeadTemporal
|
25 |
+
from extern.video_depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
|
26 |
+
|
27 |
+
from extern.video_depth_anything.util.util import compute_scale_and_shift, get_interpolate_frames
|
28 |
+
|
29 |
+
# infer settings, do not change
|
30 |
+
INFER_LEN = 32
|
31 |
+
OVERLAP = 10
|
32 |
+
KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
|
33 |
+
INTERP_LEN = 8
|
34 |
+
|
35 |
+
class VideoDepthAnything(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
encoder='vitl',
|
39 |
+
features=256,
|
40 |
+
out_channels=[256, 512, 1024, 1024],
|
41 |
+
use_bn=False,
|
42 |
+
use_clstoken=False,
|
43 |
+
num_frames=32,
|
44 |
+
pe='ape'
|
45 |
+
):
|
46 |
+
super(VideoDepthAnything, self).__init__()
|
47 |
+
|
48 |
+
self.intermediate_layer_idx = {
|
49 |
+
'vits': [2, 5, 8, 11],
|
50 |
+
'vitl': [4, 11, 17, 23]
|
51 |
+
}
|
52 |
+
|
53 |
+
self.encoder = encoder
|
54 |
+
self.pretrained = DINOv2(model_name=encoder)
|
55 |
+
|
56 |
+
self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
B, T, C, H, W = x.shape
|
60 |
+
patch_h, patch_w = H // 14, W // 14
|
61 |
+
features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
|
62 |
+
depth = self.head(features, patch_h, patch_w, T)
|
63 |
+
depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
|
64 |
+
depth = F.relu(depth)
|
65 |
+
return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]
|
66 |
+
|
67 |
+
def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
|
68 |
+
frame_height, frame_width = frames[0].shape[:2]
|
69 |
+
ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
|
70 |
+
if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation
|
71 |
+
input_size = int(input_size * 1.777 / ratio)
|
72 |
+
input_size = round(input_size / 14) * 14
|
73 |
+
|
74 |
+
transform = Compose([
|
75 |
+
Resize(
|
76 |
+
width=input_size,
|
77 |
+
height=input_size,
|
78 |
+
resize_target=False,
|
79 |
+
keep_aspect_ratio=True,
|
80 |
+
ensure_multiple_of=14,
|
81 |
+
resize_method='lower_bound',
|
82 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
83 |
+
),
|
84 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
85 |
+
PrepareForNet(),
|
86 |
+
])
|
87 |
+
|
88 |
+
frame_list = [frames[i] for i in range(frames.shape[0])]
|
89 |
+
frame_step = INFER_LEN - OVERLAP
|
90 |
+
org_video_len = len(frame_list)
|
91 |
+
append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
|
92 |
+
frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
|
93 |
+
|
94 |
+
depth_list = []
|
95 |
+
pre_input = None
|
96 |
+
for frame_id in tqdm(range(0, org_video_len, frame_step)):
|
97 |
+
cur_list = []
|
98 |
+
for i in range(INFER_LEN):
|
99 |
+
cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
|
100 |
+
cur_input = torch.cat(cur_list, dim=1).to(device)
|
101 |
+
if pre_input is not None:
|
102 |
+
cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]
|
103 |
+
|
104 |
+
with torch.no_grad():
|
105 |
+
depth = self.forward(cur_input) # depth shape: [1, T, H, W]
|
106 |
+
|
107 |
+
depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
|
108 |
+
depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]
|
109 |
+
|
110 |
+
pre_input = cur_input
|
111 |
+
|
112 |
+
del frame_list
|
113 |
+
gc.collect()
|
114 |
+
|
115 |
+
depth_list_aligned = []
|
116 |
+
ref_align = []
|
117 |
+
align_len = OVERLAP - INTERP_LEN
|
118 |
+
kf_align_list = KEYFRAMES[:align_len]
|
119 |
+
|
120 |
+
for frame_id in range(0, len(depth_list), INFER_LEN):
|
121 |
+
if len(depth_list_aligned) == 0:
|
122 |
+
depth_list_aligned += depth_list[:INFER_LEN]
|
123 |
+
for kf_id in kf_align_list:
|
124 |
+
ref_align.append(depth_list[frame_id+kf_id])
|
125 |
+
else:
|
126 |
+
curr_align = []
|
127 |
+
for i in range(len(kf_align_list)):
|
128 |
+
curr_align.append(depth_list[frame_id+i])
|
129 |
+
scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
|
130 |
+
np.concatenate(ref_align),
|
131 |
+
np.concatenate(np.ones_like(ref_align)==1))
|
132 |
+
|
133 |
+
pre_depth_list = depth_list_aligned[-INTERP_LEN:]
|
134 |
+
post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
|
135 |
+
for i in range(len(post_depth_list)):
|
136 |
+
post_depth_list[i] = post_depth_list[i] * scale + shift
|
137 |
+
post_depth_list[i][post_depth_list[i]<0] = 0
|
138 |
+
depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
|
139 |
+
|
140 |
+
for i in range(OVERLAP, INFER_LEN):
|
141 |
+
new_depth = depth_list[frame_id+i] * scale + shift
|
142 |
+
new_depth[new_depth<0] = 0
|
143 |
+
depth_list_aligned.append(new_depth)
|
144 |
+
|
145 |
+
ref_align = ref_align[:1]
|
146 |
+
for kf_id in kf_align_list[1:]:
|
147 |
+
new_depth = depth_list[frame_id+kf_id] * scale + shift
|
148 |
+
new_depth[new_depth<0] = 0
|
149 |
+
ref_align.append(new_depth)
|
150 |
+
|
151 |
+
depth_list = depth_list_aligned
|
152 |
+
|
153 |
+
return np.stack(depth_list[:org_video_len], axis=0), target_fps
|
154 |
+
|