TrajectoryCrafter commited on
Commit
0f56e8b
·
1 Parent(s): d794a86
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +201 -0
  2. app.py +284 -0
  3. demo.py +377 -0
  4. docs/config_help.md +27 -0
  5. extern/depthcrafter/__init__.py +0 -0
  6. extern/depthcrafter/__pycache__/__init__.cpython-310.pyc +0 -0
  7. extern/depthcrafter/__pycache__/demo.cpython-310.pyc +0 -0
  8. extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc +0 -0
  9. extern/depthcrafter/__pycache__/infer.cpython-310.pyc +0 -0
  10. extern/depthcrafter/__pycache__/unet.cpython-310.pyc +0 -0
  11. extern/depthcrafter/__pycache__/utils.cpython-310.pyc +0 -0
  12. extern/depthcrafter/depth_crafter_ppl.py +366 -0
  13. extern/depthcrafter/infer.py +91 -0
  14. extern/depthcrafter/unet.py +142 -0
  15. extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc +0 -0
  16. extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc +0 -0
  17. extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc +0 -0
  18. extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc +0 -0
  19. extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc +0 -0
  20. extern/video_depth_anything/dinov2.py +415 -0
  21. extern/video_depth_anything/dinov2_layers/__init__.py +11 -0
  22. extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc +0 -0
  23. extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc +0 -0
  24. extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc +0 -0
  25. extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc +0 -0
  26. extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
  27. extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc +0 -0
  28. extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
  29. extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
  30. extern/video_depth_anything/dinov2_layers/attention.py +83 -0
  31. extern/video_depth_anything/dinov2_layers/block.py +252 -0
  32. extern/video_depth_anything/dinov2_layers/drop_path.py +35 -0
  33. extern/video_depth_anything/dinov2_layers/layer_scale.py +28 -0
  34. extern/video_depth_anything/dinov2_layers/mlp.py +41 -0
  35. extern/video_depth_anything/dinov2_layers/patch_embed.py +89 -0
  36. extern/video_depth_anything/dinov2_layers/swiglu_ffn.py +63 -0
  37. extern/video_depth_anything/dpt.py +160 -0
  38. extern/video_depth_anything/dpt_temporal.py +96 -0
  39. extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc +0 -0
  40. extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc +0 -0
  41. extern/video_depth_anything/motion_module/attention.py +429 -0
  42. extern/video_depth_anything/motion_module/motion_module.py +297 -0
  43. extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc +0 -0
  44. extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc +0 -0
  45. extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc +0 -0
  46. extern/video_depth_anything/util/blocks.py +162 -0
  47. extern/video_depth_anything/util/transform.py +158 -0
  48. extern/video_depth_anything/util/util.py +74 -0
  49. extern/video_depth_anything/vdademo.py +63 -0
  50. extern/video_depth_anything/video_depth.py +154 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import sys
4
+ from demo import TrajCrafter
5
+ import random
6
+ import gradio as gr
7
+ import random
8
+ from inference import get_parser
9
+ from datetime import datetime
10
+ import argparse
11
+
12
+ # 解析命令行参数
13
+
14
+ traj_examples = [
15
+ ['20; -30; 0.3; 0; 0'],
16
+ ['0; 0; -0.3; -2; 2'],
17
+ ]
18
+
19
+ # inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
20
+
21
+ img_examples = [
22
+ ['test/videos/0-NNvgaTcVzAG0-r.mp4',2,1,'0; -30; 0.5; -2; 0',50,43],
23
+ ['test/videos/tUfDESZsQFhdDW9S.mp4',2,1,'0; 30; -0.4; 2; 0',50,43],
24
+ ['test/videos/part-2-3.mp4',2,1,'20; 40; 0.5; 2; 0',50,43],
25
+ ['test/videos/p7.mp4',2,1,'0; -50; 0.3; 0; 0',50,43],
26
+ ['test/videos/UST-fn-RvhJwMR5S.mp4',2,1,'0; -35; 0.4; 0; 0',50,43],
27
+ ]
28
+
29
+ max_seed = 2 ** 31
30
+
31
+ parser = get_parser() # infer_config.py
32
+ opts = parser.parse_args() # default device: 'cuda:0'
33
+ opts.weight_dtype = torch.bfloat16
34
+ tmp = datetime.now().strftime("%Y%m%d_%H%M")
35
+ opts.save_dir = f'./experiments/gradio_{tmp}'
36
+ os.makedirs(opts.save_dir,exist_ok=True)
37
+ test_tensor = torch.Tensor([0]).cuda()
38
+ opts.device = str(test_tensor.device)
39
+
40
+ CAMERA_MOTION_MODE = ["Basic Camera Trajectory", "Custom Camera Trajectory"]
41
+
42
+ def show_traj(mode):
43
+ if mode == 'Orbit Left':
44
+ return gr.update(value='0; -30; 0; 0; 0',visible=True),gr.update(visible=False)
45
+ elif mode == 'Orbit Right':
46
+ return gr.update(value='0; 30; 0; 0; 0',visible=True),gr.update(visible=False)
47
+ elif mode == 'Orbit Up':
48
+ return gr.update(value='30; 0; 0; 0; 0',visible=True),gr.update(visible=False)
49
+ elif mode == 'Orbit Down':
50
+ return gr.update(value='-20; 0; 0; 0; 0',visible=True), gr.update(visible=False)
51
+ if mode == 'Pan Left':
52
+ return gr.update(value='0; 0; 0; -2; 0',visible=True),gr.update(visible=False)
53
+ elif mode == 'Pan Right':
54
+ return gr.update(value='0; 0; 0; 2; 0',visible=True),gr.update(visible=False)
55
+ elif mode == 'Pan Up':
56
+ return gr.update(value='0; 0; 0; 0; 2',visible=True),gr.update(visible=False)
57
+ elif mode == 'Pan Down':
58
+ return gr.update(value='0; 0; 0; 0; -2',visible=True), gr.update(visible=False)
59
+ elif mode == 'Zoom in':
60
+ return gr.update(value='0; 0; 0.5; 0; 0',visible=True), gr.update(visible=False)
61
+ elif mode == 'Zoom out':
62
+ return gr.update(value='0; 0; -0.5; 0; 0',visible=True), gr.update(visible=False)
63
+ elif mode == 'Customize':
64
+ return gr.update(value='0; 0; 0; 0; 0',visible=True), gr.update(visible=True)
65
+ elif mode == 'Reset':
66
+ return gr.update(value='0; 0; 0; 0; 0',visible=False), gr.update(visible=False)
67
+
68
+ def trajcrafter_demo(opts):
69
+ # css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height:576px} #random_button {max-width: 100px !important}"""
70
+ css = """
71
+ #input_img {max-width: 1024px !important}
72
+ #output_vid {max-width: 1024px; max-height:576px}
73
+ #random_button {max-width: 100px !important}
74
+ .generate-btn {
75
+ background: linear-gradient(45deg, #2196F3, #1976D2) !important;
76
+ border: none !important;
77
+ color: white !important;
78
+ font-weight: bold !important;
79
+ box-shadow: 0 2px 5px rgba(0,0,0,0.2) !important;
80
+ }
81
+ .generate-btn:hover {
82
+ background: linear-gradient(45deg, #1976D2, #1565C0) !important;
83
+ box-shadow: 0 4px 8px rgba(0,0,0,0.3) !important;
84
+ }
85
+ """
86
+ image2video = TrajCrafter(opts,gradio=True)
87
+ # image2video.run_both = spaces.GPU(image2video.run_both, duration=290) # fixme
88
+ with gr.Blocks(analytics_enabled=False, css=css) as trajcrafter_iface:
89
+ gr.Markdown("<div align='center'> <h1> TrajectoryCrafter: Redirecting View Trajectory for Monocular Videos via Diffusion Models </span> </h1>")
90
+ # # <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
91
+ # # <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02048'> [ArXiv] </a>\
92
+ # # <a style='font-size:18px;color: #000000' href='https://drexubery.github.io/ViewCrafter/'> [Project Page] </a>\
93
+ # # <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Drexubery/ViewCrafter'> [Github] </a>\
94
+ # # <a style='font-size:18px;color: #000000' href='https://www.youtube.com/watch?v=WGIEmu9eXmU'> [Video] </a> </div>")
95
+
96
+
97
+ with gr.Row(equal_height=True):
98
+ with gr.Column():
99
+ # # step 1: input an image
100
+ # gr.Markdown("---\n## Step 1: Input an Image, selet an elevation angle and a center_scale factor", show_label=False, visible=True)
101
+ # gr.Markdown("<div align='left' style='font-size:18px;color: #000000'>1. Estimate an elevation angle that represents the angle at which the image was taken; a value bigger than 0 indicates a top-down view, and it doesn't need to be precise. <br>2. The origin of the world coordinate system is by default defined at the point cloud corresponding to the center pixel of the input image. You can adjust the position of the origin by modifying center_scale; a value smaller than 1 brings the origin closer to you.</div>")
102
+ i2v_input_video = gr.Video(label="Input Video", elem_id="input_video", format="mp4")
103
+
104
+
105
+ with gr.Column():
106
+ i2v_output_video = gr.Video(label="Generated Video", elem_id="output_vid", autoplay=True,
107
+ show_share_button=True)
108
+
109
+ with gr.Row():
110
+ with gr.Row():
111
+ i2v_stride = gr.Slider(minimum=1, maximum=3, step=1, elem_id="stride", label="Stride", value=1)
112
+ i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale",
113
+ label="center_scale", value=1)
114
+ i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps",
115
+ value=50)
116
+ i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
117
+ with gr.Row():
118
+ pan_left = gr.Button(value="Pan Left")
119
+ pan_right = gr.Button(value="Pan Right")
120
+ pan_up = gr.Button(value="Pan Up")
121
+ pan_down = gr.Button(value="Pan Down")
122
+ with gr.Row():
123
+ orbit_left = gr.Button(value="Orbit Left")
124
+ orbit_right = gr.Button(value="Orbit Right")
125
+ orbit_up = gr.Button(value="Orbit Up")
126
+ orbit_down = gr.Button(value="Orbit Down")
127
+ with gr.Row():
128
+ zin = gr.Button(value="Zoom in")
129
+ zout = gr.Button(value="Zoom out")
130
+ custom = gr.Button(value="Customize")
131
+ reset = gr.Button(value="Reset")
132
+ with gr.Column():
133
+ with gr.Row():
134
+ with gr.Column():
135
+ i2v_pose = gr.Text(value='0; 0; 0; 0; 0', label="Traget camera pose (theta, phi, r, x, y)",
136
+ visible=False)
137
+ with gr.Column(visible=False) as i2v_egs:
138
+ gr.Markdown(
139
+ "<div align='left' style='font-size:18px;color: #000000'>Please refer to <a href='https://github.com/TrajectoryCrafter/TrajectoryCrafter/blob/main/docs/config_help.md' target='_blank'>tutorial</a> for customizing camera trajectory.</div>")
140
+ gr.Examples(examples=traj_examples,
141
+ inputs=[i2v_pose],
142
+ )
143
+ with gr.Column():
144
+ i2v_end_btn = gr.Button("Generate video", scale=2, size="lg", variant="primary", elem_classes="generate-btn")
145
+
146
+
147
+ # with gr.Column():
148
+ # i2v_input_video = gr.Video(label="Input Video", elem_id="input_video", format="mp4")
149
+ # i2v_input_image = gr.Image(label="Input Image",elem_id="input_img")
150
+ # with gr.Row():
151
+ # # i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
152
+ # i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
153
+ # i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
154
+ # i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
155
+ # with gr.Column():
156
+ # with gr.Row():
157
+ # left = gr.Button(value = "Left")
158
+ # right = gr.Button(value = "Right")
159
+ # up = gr.Button(value = "Up")
160
+ # with gr.Row():
161
+ # down = gr.Button(value = "Down")
162
+ # zin = gr.Button(value = "Zoom in")
163
+ # zout = gr.Button(value = "Zoom out")
164
+ # with gr.Row():
165
+ # custom = gr.Button(value = "Customize")
166
+ # reset = gr.Button(value = "Reset")
167
+
168
+
169
+ # step 3 - Generate video
170
+ # with gr.Column():
171
+ # gr.Markdown("---\n## Step 3: Generate video", show_label=False, visible=True)
172
+ # gr.Markdown("<div align='left' style='font-size:18px;color: #000000'> You can reduce the sampling steps for faster inference; try different random seed if the result is not satisfying. </div>")
173
+ # i2v_output_video = gr.Video(label="Generated Video",elem_id="output_vid",autoplay=True,show_share_button=True)
174
+ # i2v_end_btn = gr.Button("Generate video")
175
+ # i2v_traj_video = gr.Video(label="Camera Trajectory",elem_id="traj_vid",autoplay=True,show_share_button=True)
176
+
177
+ # with gr.Column(scale=1.5):
178
+ # with gr.Row():
179
+ # # i2v_elevation = gr.Slider(minimum=-45, maximum=45, step=1, elem_id="elevation", label="elevation", value=5)
180
+ # i2v_center_scale = gr.Slider(minimum=0.1, maximum=2, step=0.1, elem_id="i2v_center_scale", label="center_scale", value=1)
181
+ # i2v_steps = gr.Slider(minimum=1, maximum=50, step=1, elem_id="i2v_steps", label="Sampling steps", value=50)
182
+ # i2v_seed = gr.Slider(label='Random seed', minimum=0, maximum=max_seed, step=1, value=43)
183
+ # with gr.Row():
184
+ # pan_left = gr.Button(value = "Pan Left")
185
+ # pan_right = gr.Button(value = "Pan Right")
186
+ # pan_up = gr.Button(value = "Pan Up")
187
+ # pan_down = gr.Button(value = "Pan Down")
188
+ # with gr.Row():
189
+ # orbit_left = gr.Button(value = "Orbit Left")
190
+ # orbit_right = gr.Button(value = "Orbit Right")
191
+ # orbit_up = gr.Button(value = "Orbit Up")
192
+ # orbit_down = gr.Button(value = "Orbit Down")
193
+ # with gr.Row():
194
+ # zin = gr.Button(value = "Zoom in")
195
+ # zout = gr.Button(value = "Zoom out")
196
+ # custom = gr.Button(value = "Customize")
197
+ # reset = gr.Button(value = "Reset")
198
+ # with gr.Column():
199
+ # with gr.Row():
200
+ # with gr.Column():
201
+ # i2v_pose = gr.Text(value = '0; 0; 0; 0; 0', label="Traget camera pose (theta, phi, r, x, y)",visible=False)
202
+ # with gr.Column(visible=False) as i2v_egs:
203
+ # gr.Markdown("<div align='left' style='font-size:18px;color: #000000'>Please refer to the <a href='https://github.com/Drexubery/ViewCrafter/blob/main/docs/gradio_tutorial.md' target='_blank'>tutorial</a> for customizing camera trajectory.</div>")
204
+ # gr.Examples(examples=traj_examples,
205
+ # inputs=[i2v_pose],
206
+ # )
207
+ # with gr.Row():
208
+ # i2v_end_btn = gr.Button("Generate video")
209
+ # step 3 - Generate video
210
+ # with gr.Row():
211
+ # with gr.Column():
212
+
213
+
214
+
215
+ i2v_end_btn.click(inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
216
+ outputs=[i2v_output_video],
217
+ fn = image2video.run_gradio
218
+ )
219
+
220
+ pan_left.click(inputs=[pan_left],
221
+ outputs=[i2v_pose,i2v_egs],
222
+ fn = show_traj
223
+ )
224
+ pan_right.click(inputs=[pan_right],
225
+ outputs=[i2v_pose,i2v_egs],
226
+ fn = show_traj
227
+ )
228
+ pan_up.click(inputs=[pan_up],
229
+ outputs=[i2v_pose,i2v_egs],
230
+ fn = show_traj
231
+ )
232
+ pan_down.click(inputs=[pan_down],
233
+ outputs=[i2v_pose,i2v_egs],
234
+ fn = show_traj
235
+ )
236
+ orbit_left.click(inputs=[orbit_left],
237
+ outputs=[i2v_pose,i2v_egs],
238
+ fn = show_traj
239
+ )
240
+ orbit_right.click(inputs=[orbit_right],
241
+ outputs=[i2v_pose,i2v_egs],
242
+ fn = show_traj
243
+ )
244
+ orbit_up.click(inputs=[orbit_up],
245
+ outputs=[i2v_pose,i2v_egs],
246
+ fn = show_traj
247
+ )
248
+ orbit_down.click(inputs=[orbit_down],
249
+ outputs=[i2v_pose,i2v_egs],
250
+ fn = show_traj
251
+ )
252
+ zin.click(inputs=[zin],
253
+ outputs=[i2v_pose,i2v_egs],
254
+ fn = show_traj
255
+ )
256
+ zout.click(inputs=[zout],
257
+ outputs=[i2v_pose,i2v_egs],
258
+ fn = show_traj
259
+ )
260
+ custom.click(inputs=[custom],
261
+ outputs=[i2v_pose,i2v_egs],
262
+ fn = show_traj
263
+ )
264
+ reset.click(inputs=[reset],
265
+ outputs=[i2v_pose,i2v_egs],
266
+ fn = show_traj
267
+ )
268
+
269
+
270
+ gr.Examples(examples=img_examples,
271
+ # inputs=[i2v_input_video,i2v_stride],
272
+ inputs=[i2v_input_video, i2v_stride, i2v_center_scale, i2v_pose, i2v_steps, i2v_seed],
273
+ )
274
+
275
+ return trajcrafter_iface
276
+
277
+
278
+ trajcrafter_iface = trajcrafter_demo(opts)
279
+ trajcrafter_iface.queue(max_size=10)
280
+ # trajcrafter_iface.launch(server_name=args.server_name, max_threads=10, debug=True)
281
+ trajcrafter_iface.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False, max_threads=10)
282
+
283
+
284
+
demo.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import torch
4
+ from extern.depthcrafter.infer import DepthCrafterDemo
5
+ # from extern.video_depth_anything.vdademo import VDADemo
6
+ import numpy as np
7
+ import torch
8
+ from transformers import T5EncoderModel
9
+ from omegaconf import OmegaConf
10
+ from PIL import Image
11
+ from models.crosstransformer3d import CrossTransformer3DModel
12
+ from models.autoencoder_magvit import AutoencoderKLCogVideoX
13
+ from models.pipeline_trajectorycrafter import TrajCrafter_Pipeline
14
+ from models.utils import *
15
+ from diffusers import (AutoencoderKL, CogVideoXDDIMScheduler, DDIMScheduler,
16
+ DPMSolverMultistepScheduler,
17
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
18
+ PNDMScheduler)
19
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
20
+
21
+ class TrajCrafter:
22
+ def __init__(self, opts, gradio=False):
23
+ self.funwarp = Warper(device=opts.device)
24
+ # self.depth_estimater = VDADemo(pre_train_path=opts.pre_train_path_vda,device=opts.device)
25
+ self.depth_estimater = DepthCrafterDemo(unet_path=opts.unet_path,pre_train_path=opts.pre_train_path,cpu_offload=opts.cpu_offload,device=opts.device)
26
+ self.caption_processor = AutoProcessor.from_pretrained(opts.blip_path)
27
+ self.captioner = Blip2ForConditionalGeneration.from_pretrained(opts.blip_path, torch_dtype=torch.float16).to(opts.device)
28
+ self.setup_diffusion(opts)
29
+ if gradio:
30
+ self.opts=opts
31
+
32
+ def infer_gradual(self,opts):
33
+ frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
34
+ prompt = self.get_caption(opts,frames[opts.video_length//2])
35
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
36
+ depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
37
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
38
+ assert frames.shape[0] == opts.video_length
39
+ pose_s, pose_t, K = self.get_poses(opts,depths,num_frames = opts.video_length)
40
+ warped_images = []
41
+ masks = []
42
+ for i in tqdm(range(opts.video_length)):
43
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i:i+1], None, depths[i:i+1], pose_s[i:i+1], pose_t[i:i+1], K[i:i+1], None, opts.mask,twice=False)
44
+ warped_images.append(warped_frame2)
45
+ masks.append(mask2)
46
+ cond_video = (torch.cat(warped_images)+1.)/2.
47
+ cond_masks = torch.cat(masks)
48
+
49
+ frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
50
+ cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
51
+ cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
52
+ save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
53
+ save_video(cond_video.permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
54
+ save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
55
+
56
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
57
+ frames_ref = frames[:,:,:10,:,:]
58
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
59
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
60
+ generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
61
+
62
+ del self.depth_estimater
63
+ del self.caption_processor
64
+ del self.captioner
65
+ gc.collect()
66
+ torch.cuda.empty_cache()
67
+ with torch.no_grad():
68
+ sample = self.pipeline(
69
+ prompt,
70
+ num_frames = opts.video_length,
71
+ negative_prompt = opts.negative_prompt,
72
+ height = opts.sample_size[0],
73
+ width = opts.sample_size[1],
74
+ generator = generator,
75
+ guidance_scale = opts.diffusion_guidance_scale,
76
+ num_inference_steps = opts.diffusion_inference_steps,
77
+ video = cond_video,
78
+ mask_video = cond_masks,
79
+ reference = frames_ref,
80
+ ).videos
81
+ save_video(sample[0].permute(1,2,3,0), os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
82
+
83
+ viz = True
84
+ if viz:
85
+ tensor_left = frames[0].to(opts.device)
86
+ tensor_right = sample[0].to(opts.device)
87
+ interval = torch.ones(3, 49, 384, 30).to(opts.device)
88
+ result = torch.cat((tensor_left, interval, tensor_right), dim=3)
89
+ result_reverse = torch.flip(result, dims=[1])
90
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
91
+ save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*2)
92
+
93
+ def infer_direct(self,opts):
94
+ opts.cut = 20
95
+ frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
96
+ prompt = self.get_caption(opts,frames[opts.video_length//2])
97
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
98
+ depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
99
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
100
+ assert frames.shape[0] == opts.video_length
101
+ pose_s, pose_t, K = self.get_poses(opts,depths,num_frames = opts.cut)
102
+
103
+ warped_images = []
104
+ masks = []
105
+ for i in tqdm(range(opts.video_length)):
106
+ if i < opts.cut:
107
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[0:1], None, depths[0:1], pose_s[0:1], pose_t[i:i+1], K[0:1], None, opts.mask,twice=False)
108
+ warped_images.append(warped_frame2)
109
+ masks.append(mask2)
110
+ else:
111
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i-opts.cut:i-opts.cut+1], None, depths[i-opts.cut:i-opts.cut+1], pose_s[0:1], pose_t[-1:], K[0:1], None, opts.mask,twice=False)
112
+ warped_images.append(warped_frame2)
113
+ masks.append(mask2)
114
+ cond_video = (torch.cat(warped_images)+1.)/2.
115
+ cond_masks = torch.cat(masks)
116
+ frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
117
+ cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
118
+ cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
119
+ save_video((frames[:opts.video_length-opts.cut].permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
120
+ save_video(cond_video[opts.cut:].permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
121
+ save_video(cond_masks[opts.cut:].repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
122
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
123
+ frames_ref = frames[:,:,:10,:,:]
124
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
125
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
126
+ generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
127
+
128
+ del self.depth_estimater
129
+ del self.caption_processor
130
+ del self.captioner
131
+ gc.collect()
132
+ torch.cuda.empty_cache()
133
+ with torch.no_grad():
134
+ sample = self.pipeline(
135
+ prompt,
136
+ num_frames = opts.video_length,
137
+ negative_prompt = opts.negative_prompt,
138
+ height = opts.sample_size[0],
139
+ width = opts.sample_size[1],
140
+ generator = generator,
141
+ guidance_scale = opts.diffusion_guidance_scale,
142
+ num_inference_steps = opts.diffusion_inference_steps,
143
+ video = cond_video,
144
+ mask_video = cond_masks,
145
+ reference = frames_ref,
146
+ ).videos
147
+ save_video(sample[0].permute(1,2,3,0)[opts.cut:], os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
148
+
149
+ viz = True
150
+ if viz:
151
+ tensor_left = frames[0][:,:opts.video_length-opts.cut,...].to(opts.device)
152
+ tensor_right = sample[0][:,opts.cut:,...].to(opts.device)
153
+ interval = torch.ones(3, opts.video_length-opts.cut, 384, 30).to(opts.device)
154
+ result = torch.cat((tensor_left, interval, tensor_right), dim=3)
155
+ result_reverse = torch.flip(result, dims=[1])
156
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
157
+ save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*2)
158
+
159
+ def infer_bullet(self,opts):
160
+ frames = read_video_frames(opts.video_path,opts.video_length,opts.stride,opts.max_res)
161
+ prompt = self.get_caption(opts,frames[opts.video_length//2])
162
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
163
+ depths= self.depth_estimater.infer(frames, opts.near, opts.far, opts.depth_inference_steps, opts.depth_guidance_scale, window_size=opts.window_size, overlap=opts.overlap).to(opts.device)
164
+
165
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
166
+ assert frames.shape[0] == opts.video_length
167
+ pose_s, pose_t, K = self.get_poses(opts,depths, num_frames = opts.video_length)
168
+
169
+ warped_images = []
170
+ masks = []
171
+ for i in tqdm(range(opts.video_length)):
172
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[-1:], None, depths[-1:], pose_s[0:1], pose_t[i:i+1], K[0:1], None, opts.mask,twice=False)
173
+ warped_images.append(warped_frame2)
174
+ masks.append(mask2)
175
+ cond_video = (torch.cat(warped_images)+1.)/2.
176
+ cond_masks = torch.cat(masks)
177
+ frames = F.interpolate(frames, size=opts.sample_size, mode='bilinear', align_corners=False)
178
+ cond_video = F.interpolate(cond_video, size=opts.sample_size, mode='bilinear', align_corners=False)
179
+ cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')
180
+ save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(opts.save_dir,'input.mp4'),fps=opts.fps)
181
+ save_video(cond_video.permute(0,2,3,1), os.path.join(opts.save_dir,'render.mp4'),fps=opts.fps)
182
+ save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(opts.save_dir,'mask.mp4'),fps=opts.fps)
183
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
184
+ frames_ref = frames[:,:,-10:,:,:]
185
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
186
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
187
+ generator = torch.Generator(device=opts.device).manual_seed(opts.seed)
188
+
189
+ del self.depth_estimater
190
+ del self.caption_processor
191
+ del self.captioner
192
+ gc.collect()
193
+ torch.cuda.empty_cache()
194
+ with torch.no_grad():
195
+ sample = self.pipeline(
196
+ prompt,
197
+ num_frames = opts.video_length,
198
+ negative_prompt = opts.negative_prompt,
199
+ height = opts.sample_size[0],
200
+ width = opts.sample_size[1],
201
+ generator = generator,
202
+ guidance_scale = opts.diffusion_guidance_scale,
203
+ num_inference_steps = opts.diffusion_inference_steps,
204
+ video = cond_video,
205
+ mask_video = cond_masks,
206
+ reference = frames_ref,
207
+ ).videos
208
+ save_video(sample[0].permute(1,2,3,0), os.path.join(opts.save_dir,'gen.mp4'), fps=opts.fps)
209
+
210
+ viz = True
211
+ if viz:
212
+ tensor_left = frames[0].to(opts.device)
213
+ tensor_left_full = torch.cat([tensor_left,tensor_left[:,-1:,:,:].repeat(1,48,1,1)],dim=1)
214
+ tensor_right = sample[0].to(opts.device)
215
+ tensor_right_full = torch.cat([tensor_left,tensor_right[:,1:,:,:]],dim=1)
216
+ interval = torch.ones(3, 49*2-1, 384, 30).to(opts.device)
217
+ result = torch.cat((tensor_left_full, interval, tensor_right_full), dim=3)
218
+ result_reverse = torch.flip(result, dims=[1])
219
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
220
+ save_video(final_result.permute(1,2,3,0), os.path.join(opts.save_dir,'viz.mp4'), fps=opts.fps*4)
221
+
222
+ def get_caption(self,opts,image):
223
+ image_array = (image * 255).astype(np.uint8)
224
+ pil_image = Image.fromarray(image_array)
225
+ inputs = self.caption_processor(images=pil_image, return_tensors="pt").to(opts.device, torch.float16)
226
+ generated_ids = self.captioner.generate(**inputs)
227
+ generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
228
+ return generated_text + opts.refine_prompt
229
+
230
+ def get_poses(self,opts,depths,num_frames):
231
+ radius = depths[0,0,depths.shape[-2]//2,depths.shape[-1]//2].cpu()*opts.radius_scale
232
+ radius = min(radius, 5)
233
+ cx = 512. #depths.shape[-1]//2
234
+ cy = 288. #depths.shape[-2]//2
235
+ f = 500 #500.
236
+ K = torch.tensor([[f, 0., cx],[ 0., f, cy],[ 0., 0., 1.]]).repeat(num_frames,1,1).to(opts.device)
237
+ c2w_init = torch.tensor([[-1., 0., 0., 0.],
238
+ [ 0., 1., 0., 0.],
239
+ [ 0., 0., -1., 0.],
240
+ [ 0., 0., 0., 1.]]).to(opts.device).unsqueeze(0)
241
+ if opts.camera == 'target':
242
+ dtheta, dphi, dr, dx, dy = opts.target_pose
243
+ poses = generate_traj_specified(c2w_init, dtheta, dphi, dr*radius, dx, dy, num_frames, opts.device)
244
+ elif opts.camera =='traj':
245
+ with open(opts.traj_txt, 'r') as file:
246
+ lines = file.readlines()
247
+ theta = [float(i) for i in lines[0].split()]
248
+ phi = [float(i) for i in lines[1].split()]
249
+ r = [float(i)*radius for i in lines[2].split()]
250
+ poses = generate_traj_txt(c2w_init, phi, theta, r, num_frames, opts.device)
251
+ poses[:,2, 3] = poses[:,2, 3] + radius
252
+ pose_s = poses[opts.anchor_idx:opts.anchor_idx+1].repeat(num_frames,1,1)
253
+ pose_t = poses
254
+ return pose_s, pose_t, K
255
+
256
+ def setup_diffusion(self,opts):
257
+ # transformer = CrossTransformer3DModel.from_pretrained_cus(opts.transformer_path).to(opts.weight_dtype)
258
+ transformer = CrossTransformer3DModel.from_pretrained(opts.transformer_path).to(opts.weight_dtype)
259
+ # transformer = transformer.to(opts.weight_dtype)
260
+ vae = AutoencoderKLCogVideoX.from_pretrained(
261
+ opts.model_name,
262
+ subfolder="vae"
263
+ ).to(opts.weight_dtype)
264
+ text_encoder = T5EncoderModel.from_pretrained(
265
+ opts.model_name, subfolder="text_encoder", torch_dtype=opts.weight_dtype
266
+ )
267
+ # Get Scheduler
268
+ Choosen_Scheduler = {
269
+ "Euler": EulerDiscreteScheduler,
270
+ "Euler A": EulerAncestralDiscreteScheduler,
271
+ "DPM++": DPMSolverMultistepScheduler,
272
+ "PNDM": PNDMScheduler,
273
+ "DDIM_Cog": CogVideoXDDIMScheduler,
274
+ "DDIM_Origin": DDIMScheduler,
275
+ }[opts.sampler_name]
276
+ scheduler = Choosen_Scheduler.from_pretrained(
277
+ opts.model_name,
278
+ subfolder="scheduler"
279
+ )
280
+
281
+ self.pipeline = TrajCrafter_Pipeline.from_pretrained(
282
+ opts.model_name,
283
+ vae=vae,
284
+ text_encoder=text_encoder,
285
+ transformer=transformer,
286
+ scheduler=scheduler,
287
+ torch_dtype=opts.weight_dtype
288
+ )
289
+
290
+ if opts.low_gpu_memory_mode:
291
+ self.pipeline.enable_sequential_cpu_offload()
292
+ else:
293
+ self.pipeline.enable_model_cpu_offload()
294
+
295
+ def run_gradio(self,input_video, stride, radius_scale, pose, steps, seed):
296
+ frames = read_video_frames(input_video, self.opts.video_length, stride,self.opts.max_res)
297
+ prompt = self.get_caption(self.opts,frames[self.opts.video_length//2])
298
+ # depths= self.depth_estimater.infer(frames, opts.near, opts.far).to(opts.device)
299
+ depths= self.depth_estimater.infer(frames, self.opts.near, self.opts.far, self.opts.depth_inference_steps, self.opts.depth_guidance_scale, window_size=self.opts.window_size, overlap=self.opts.overlap).to(self.opts.device)
300
+ frames = torch.from_numpy(frames).permute(0,3,1,2).to(self.opts.device)*2.-1. # 49 576 1024 3 -> 49 3 576 1024, [-1,1]
301
+ num_frames = frames.shape[0]
302
+ assert num_frames == self.opts.video_length
303
+ radius_scale = float(radius_scale)
304
+ radius = depths[0,0,depths.shape[-2]//2,depths.shape[-1]//2].cpu()*radius_scale
305
+ radius = min(radius, 5)
306
+ cx = 512. #depths.shape[-1]//2
307
+ cy = 288. #depths.shape[-2]//2
308
+ f = 500 #500.
309
+ K = torch.tensor([[f, 0., cx],[ 0., f, cy],[ 0., 0., 1.]]).repeat(num_frames,1,1).to(self.opts.device)
310
+ c2w_init = torch.tensor([[-1., 0., 0., 0.],
311
+ [ 0., 1., 0., 0.],
312
+ [ 0., 0., -1., 0.],
313
+ [ 0., 0., 0., 1.]]).to(self.opts.device).unsqueeze(0)
314
+
315
+ # import pdb
316
+ # pdb.set_trace()
317
+ theta,phi,r,x,y = [float(i) for i in pose.split(';')]
318
+ # theta,phi,r,x,y = [float(i) for i in theta.split()],[float(i) for i in phi.split()],[float(i) for i in r.split()],[float(i) for i in x.split()],[float(i) for i in y.split()]
319
+ # target mode
320
+ poses = generate_traj_specified(c2w_init, theta, phi, r*radius, x, y, num_frames, self.opts.device)
321
+ poses[:,2, 3] = poses[:,2, 3] + radius
322
+ pose_s = poses[self.opts.anchor_idx:self.opts.anchor_idx+1].repeat(num_frames,1,1)
323
+ pose_t = poses
324
+
325
+ warped_images = []
326
+ masks = []
327
+ for i in tqdm(range(self.opts.video_length)):
328
+ warped_frame2, mask2, warped_depth2, flow12 = self.funwarp.forward_warp(frames[i:i+1], None, depths[i:i+1], pose_s[i:i+1], pose_t[i:i+1], K[i:i+1], None, self.opts.mask,twice=False)
329
+ warped_images.append(warped_frame2)
330
+ masks.append(mask2)
331
+ cond_video = (torch.cat(warped_images)+1.)/2.
332
+ cond_masks = torch.cat(masks)
333
+
334
+ frames = F.interpolate(frames, size=self.opts.sample_size, mode='bilinear', align_corners=False)
335
+ cond_video = F.interpolate(cond_video, size=self.opts.sample_size, mode='bilinear', align_corners=False)
336
+ cond_masks = F.interpolate(cond_masks, size=self.opts.sample_size, mode='nearest')
337
+ save_video((frames.permute(0,2,3,1)+1.)/2., os.path.join(self.opts.save_dir,'input.mp4'),fps=self.opts.fps)
338
+ save_video(cond_video.permute(0,2,3,1), os.path.join(self.opts.save_dir,'render.mp4'),fps=self.opts.fps)
339
+ save_video(cond_masks.repeat(1,3,1,1).permute(0,2,3,1), os.path.join(self.opts.save_dir,'mask.mp4'),fps=self.opts.fps)
340
+
341
+ frames = (frames.permute(1,0,2,3).unsqueeze(0)+1.)/2.
342
+ frames_ref = frames[:,:,:10,:,:]
343
+ cond_video = cond_video.permute(1,0,2,3).unsqueeze(0)
344
+ cond_masks = (1.-cond_masks.permute(1,0,2,3).unsqueeze(0))*255.
345
+ generator = torch.Generator(device=self.opts.device).manual_seed(seed)
346
+
347
+ # del self.depth_estimater
348
+ # del self.caption_processor
349
+ # del self.captioner
350
+ # gc.collect()
351
+ torch.cuda.empty_cache()
352
+ with torch.no_grad():
353
+ sample = self.pipeline(
354
+ prompt,
355
+ num_frames = self.opts.video_length,
356
+ negative_prompt = self.opts.negative_prompt,
357
+ height = self.opts.sample_size[0],
358
+ width = self.opts.sample_size[1],
359
+ generator = generator,
360
+ guidance_scale = self.opts.diffusion_guidance_scale,
361
+ num_inference_steps = steps,
362
+ video = cond_video,
363
+ mask_video = cond_masks,
364
+ reference = frames_ref,
365
+ ).videos
366
+ save_video(sample[0].permute(1,2,3,0), os.path.join(self.opts.save_dir,'gen.mp4'), fps=self.opts.fps)
367
+
368
+ viz = True
369
+ if viz:
370
+ tensor_left = frames[0].to(self.opts.device)
371
+ tensor_right = sample[0].to(self.opts.device)
372
+ interval = torch.ones(3, 49, 384, 30).to(self.opts.device)
373
+ result = torch.cat((tensor_left, interval, tensor_right), dim=3)
374
+ result_reverse = torch.flip(result, dims=[1])
375
+ final_result = torch.cat((result, result_reverse[:,1:,:,:]), dim=1)
376
+ save_video(final_result.permute(1,2,3,0), os.path.join(self.opts.save_dir,'viz.mp4'), fps=self.opts.fps*2)
377
+ return os.path.join(self.opts.save_dir,'viz.mp4')
docs/config_help.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Important configuration for [inference.py](../inference.py):
2
+
3
+ ### 1. General configs
4
+ | Configuration | Default Value | Explanation |
5
+ |:----------------- |:--------------- |:-------------------------------------------------------- |
6
+ | `--video_path` | `None` | Input video file path |
7
+ | `--out_dir` | `./experiments/`| Output directory |
8
+ | `--device` | `cuda:0` | The device to use (e.g., CPU or GPU) |
9
+ | `--exp_name` | `None` | Experiment name, defaults to video file name |
10
+ | `--seed` | `43` | Random seed for reproducibility |
11
+ | `--video_length` | `49` | Length of the video frames (number of frames) |
12
+ | `--fps` | `10` | fps for saved video |
13
+ | `--stride` | `1` | Sampling stride for input video (frame interval) |
14
+ | `--server_name` | `None` | Server IP address for gradio |
15
+ ### 2. Point cloud render configs
16
+
17
+ | Configuration | Default Value | Explanation |
18
+ |:----------------- |:--------------- |:-------------------------------------------------------- |
19
+ | `--radius_scale` | `1.0` | Scale factor for the spherical radius |
20
+ | `--camera` | `traj` | Camera pose type, either 'traj' or 'target' |
21
+ | `--mode` | `gradual` | Mode of operation, 'gradual', 'bullet', or 'direct' |
22
+ | `--mask` | `False` | Clean the point cloud data if true |
23
+ | `--target_pose` | `None` | Required for 'target' camera pose type, specifies a relative camera pose sequece (theta, phi, r, x, y). +theta (theta<50) rotates camera upward, +phi (phi<50) rotates camera to right, +r (r<0.6) moves camera forward, +x (x<4) pans the camera to right, +y (y<4) pans the camera upward |
24
+ | `--traj_txt` | `None` | Required for 'traj' camera pose type, a txt file specifying a complex camera trajectory ([examples](../test/trajs/loop1.txt)). The fist line is the theta sequence, the second line the phi sequence, and the last line the r sequence |
25
+ | `--near` | `0.0001` | Near clipping plane distance |
26
+ | `--far` | `10000.0` | Far clipping plane distance |
27
+ | `--anchor_idx` | `0` | One GT frame for anchor frame |
extern/depthcrafter/__init__.py ADDED
File without changes
extern/depthcrafter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
extern/depthcrafter/__pycache__/demo.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
extern/depthcrafter/__pycache__/depth_crafter_ppl.cpython-310.pyc ADDED
Binary file (7.88 kB). View file
 
extern/depthcrafter/__pycache__/infer.cpython-310.pyc ADDED
Binary file (2.31 kB). View file
 
extern/depthcrafter/__pycache__/unet.cpython-310.pyc ADDED
Binary file (2.62 kB). View file
 
extern/depthcrafter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.29 kB). View file
 
extern/depthcrafter/depth_crafter_ppl.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
7
+ _resize_with_antialiasing,
8
+ StableVideoDiffusionPipelineOutput,
9
+ StableVideoDiffusionPipeline,
10
+ retrieve_timesteps,
11
+ )
12
+ from diffusers.utils import logging
13
+ from diffusers.utils.torch_utils import randn_tensor
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ class DepthCrafterPipeline(StableVideoDiffusionPipeline):
19
+
20
+ @torch.inference_mode()
21
+ def encode_video(
22
+ self,
23
+ video: torch.Tensor,
24
+ chunk_size: int = 14,
25
+ ) -> torch.Tensor:
26
+ """
27
+ :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
28
+ :param chunk_size: the chunk size to encode video
29
+ :return: image_embeddings in shape of [b, 1024]
30
+ """
31
+
32
+ video_224 = _resize_with_antialiasing(video.float(), (224, 224))
33
+ video_224 = (video_224 + 1.0) / 2.0 # [-1, 1] -> [0, 1]
34
+
35
+ embeddings = []
36
+ for i in range(0, video_224.shape[0], chunk_size):
37
+ tmp = self.feature_extractor(
38
+ images=video_224[i : i + chunk_size],
39
+ do_normalize=True,
40
+ do_center_crop=False,
41
+ do_resize=False,
42
+ do_rescale=False,
43
+ return_tensors="pt",
44
+ ).pixel_values.to(video.device, dtype=video.dtype)
45
+ embeddings.append(self.image_encoder(tmp).image_embeds) # [b, 1024]
46
+
47
+ embeddings = torch.cat(embeddings, dim=0) # [t, 1024]
48
+ return embeddings
49
+
50
+ @torch.inference_mode()
51
+ def encode_vae_video(
52
+ self,
53
+ video: torch.Tensor,
54
+ chunk_size: int = 14,
55
+ ):
56
+ """
57
+ :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
58
+ :param chunk_size: the chunk size to encode video
59
+ :return: vae latents in shape of [b, c, h, w]
60
+ """
61
+ video_latents = []
62
+ for i in range(0, video.shape[0], chunk_size):
63
+ video_latents.append(
64
+ self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
65
+ )
66
+ video_latents = torch.cat(video_latents, dim=0)
67
+ return video_latents
68
+
69
+ @staticmethod
70
+ def check_inputs(video, height, width):
71
+ """
72
+ :param video:
73
+ :param height:
74
+ :param width:
75
+ :return:
76
+ """
77
+ if not isinstance(video, torch.Tensor) and not isinstance(video, np.ndarray):
78
+ raise ValueError(
79
+ f"Expected `video` to be a `torch.Tensor` or `VideoReader`, but got a {type(video)}"
80
+ )
81
+
82
+ if height % 8 != 0 or width % 8 != 0:
83
+ raise ValueError(
84
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
85
+ )
86
+
87
+ @torch.no_grad()
88
+ def __call__(
89
+ self,
90
+ video: Union[np.ndarray, torch.Tensor],
91
+ height: int = 576,
92
+ width: int = 1024,
93
+ num_inference_steps: int = 25,
94
+ guidance_scale: float = 1.0,
95
+ window_size: Optional[int] = 110,
96
+ noise_aug_strength: float = 0.02,
97
+ decode_chunk_size: Optional[int] = None,
98
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
99
+ latents: Optional[torch.FloatTensor] = None,
100
+ output_type: Optional[str] = "pil",
101
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
102
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
103
+ return_dict: bool = True,
104
+ overlap: int = 25,
105
+ track_time: bool = False,
106
+ ):
107
+ """
108
+ :param video: in shape [t, h, w, c] if np.ndarray or [t, c, h, w] if torch.Tensor, in range [0, 1]
109
+ :param height:
110
+ :param width:
111
+ :param num_inference_steps:
112
+ :param guidance_scale:
113
+ :param window_size: sliding window processing size
114
+ :param fps:
115
+ :param motion_bucket_id:
116
+ :param noise_aug_strength:
117
+ :param decode_chunk_size:
118
+ :param generator:
119
+ :param latents:
120
+ :param output_type:
121
+ :param callback_on_step_end:
122
+ :param callback_on_step_end_tensor_inputs:
123
+ :param return_dict:
124
+ :return: StableVideoDiffusionPipelineOutput with the generated frames, or the frames/latents directly if return_dict is False
125
+ """
126
+ # 0. Default height and width to unet
127
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
128
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
129
+ num_frames = video.shape[0]
130
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
131
+ if num_frames <= window_size:
132
+ window_size = num_frames
133
+ overlap = 0
134
+ stride = window_size - overlap
135
+
136
+ # 1. Check inputs. Raise error if not correct
137
+ self.check_inputs(video, height, width)
138
+
139
+ # 2. Define call parameters
140
+ batch_size = 1
141
+ device = self._execution_device
142
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
143
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
144
+ # corresponds to doing no classifier free guidance.
145
+ self._guidance_scale = guidance_scale
146
+
147
+ # 3. Encode input video
148
+ if isinstance(video, np.ndarray):
149
+ video = torch.from_numpy(video.transpose(0, 3, 1, 2))
150
+ else:
151
+ assert isinstance(video, torch.Tensor)
152
+ video = video.to(device=device, dtype=self.dtype)
153
+ video = video * 2.0 - 1.0 # [0,1] -> [-1,1], in [t, c, h, w]
154
+
155
+ if track_time:
156
+ start_event = torch.cuda.Event(enable_timing=True)
157
+ encode_event = torch.cuda.Event(enable_timing=True)
158
+ denoise_event = torch.cuda.Event(enable_timing=True)
159
+ decode_event = torch.cuda.Event(enable_timing=True)
160
+ start_event.record()
161
+
162
+ video_embeddings = self.encode_video(
163
+ video, chunk_size=decode_chunk_size
164
+ ).unsqueeze(
165
+ 0
166
+ ) # [1, t, 1024]
167
+ torch.cuda.empty_cache()
168
+ # 4. Encode input image using VAE
169
+ noise = randn_tensor(
170
+ video.shape, generator=generator, device=device, dtype=video.dtype
171
+ )
172
+ video = video + noise_aug_strength * noise # in [t, c, h, w]
173
+
174
+ # pdb.set_trace()
175
+ needs_upcasting = (
176
+ self.vae.dtype == torch.float16 and self.vae.config.force_upcast
177
+ )
178
+ if needs_upcasting:
179
+ self.vae.to(dtype=torch.float32)
180
+
181
+ video_latents = self.encode_vae_video(
182
+ video.to(self.vae.dtype),
183
+ chunk_size=decode_chunk_size,
184
+ ).unsqueeze(
185
+ 0
186
+ ) # [1, t, c, h, w]
187
+
188
+ if track_time:
189
+ encode_event.record()
190
+ torch.cuda.synchronize()
191
+ elapsed_time_ms = start_event.elapsed_time(encode_event)
192
+ print(f"Elapsed time for encoding video: {elapsed_time_ms} ms")
193
+
194
+ torch.cuda.empty_cache()
195
+
196
+ # cast back to fp16 if needed
197
+ if needs_upcasting:
198
+ self.vae.to(dtype=torch.float16)
199
+
200
+ # 5. Get Added Time IDs
201
+ added_time_ids = self._get_add_time_ids(
202
+ 7,
203
+ 127,
204
+ noise_aug_strength,
205
+ video_embeddings.dtype,
206
+ batch_size,
207
+ 1,
208
+ False,
209
+ ) # [1 or 2, 3]
210
+ added_time_ids = added_time_ids.to(device)
211
+
212
+ # 6. Prepare timesteps
213
+ timesteps, num_inference_steps = retrieve_timesteps(
214
+ self.scheduler, num_inference_steps, device, None, None
215
+ )
216
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
217
+ self._num_timesteps = len(timesteps)
218
+
219
+ # 7. Prepare latent variables
220
+ num_channels_latents = self.unet.config.in_channels
221
+ latents_init = self.prepare_latents(
222
+ batch_size,
223
+ window_size,
224
+ num_channels_latents,
225
+ height,
226
+ width,
227
+ video_embeddings.dtype,
228
+ device,
229
+ generator,
230
+ latents,
231
+ ) # [1, t, c, h, w]
232
+ latents_all = None
233
+
234
+ idx_start = 0
235
+ if overlap > 0:
236
+ weights = torch.linspace(0, 1, overlap, device=device)
237
+ weights = weights.view(1, overlap, 1, 1, 1)
238
+ else:
239
+ weights = None
240
+
241
+ torch.cuda.empty_cache()
242
+
243
+ # inference strategy for long videos
244
+ # two main strategies: 1. initialize noise from the previous window, 2. stitch overlapping segments
245
+ while idx_start < num_frames - overlap:
246
+ idx_end = min(idx_start + window_size, num_frames)
247
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
248
+
249
+ # 9. Denoising loop
250
+ latents = latents_init[:, : idx_end - idx_start].clone()
251
+ latents_init = torch.cat(
252
+ [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
253
+ )
254
+
255
+ video_latents_current = video_latents[:, idx_start:idx_end]
256
+ video_embeddings_current = video_embeddings[:, idx_start:idx_end]
257
+
258
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
259
+ for i, t in enumerate(timesteps):
260
+ if latents_all is not None and i == 0:
261
+ latents[:, :overlap] = (
262
+ latents_all[:, -overlap:]
263
+ + latents[:, :overlap]
264
+ / self.scheduler.init_noise_sigma
265
+ * self.scheduler.sigmas[i]
266
+ )
267
+
268
+ latent_model_input = latents # [1, t, c, h, w]
269
+ latent_model_input = self.scheduler.scale_model_input(
270
+ latent_model_input, t
271
+ ) # [1, t, c, h, w]
272
+ latent_model_input = torch.cat(
273
+ [latent_model_input, video_latents_current], dim=2
274
+ )
275
+ noise_pred = self.unet(
276
+ latent_model_input,
277
+ t,
278
+ encoder_hidden_states=video_embeddings_current,
279
+ added_time_ids=added_time_ids,
280
+ return_dict=False,
281
+ )[0]
282
+ # perform guidance
283
+ if self.do_classifier_free_guidance:
284
+ latent_model_input = latents
285
+ latent_model_input = self.scheduler.scale_model_input(
286
+ latent_model_input, t
287
+ )
288
+ latent_model_input = torch.cat(
289
+ [latent_model_input, torch.zeros_like(latent_model_input)],
290
+ dim=2,
291
+ )
292
+ noise_pred_uncond = self.unet(
293
+ latent_model_input,
294
+ t,
295
+ encoder_hidden_states=torch.zeros_like(
296
+ video_embeddings_current
297
+ ),
298
+ added_time_ids=added_time_ids,
299
+ return_dict=False,
300
+ )[0]
301
+
302
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
303
+ noise_pred - noise_pred_uncond
304
+ )
305
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
306
+
307
+ if callback_on_step_end is not None:
308
+ callback_kwargs = {}
309
+ for k in callback_on_step_end_tensor_inputs:
310
+ callback_kwargs[k] = locals()[k]
311
+ callback_outputs = callback_on_step_end(
312
+ self, i, t, callback_kwargs
313
+ )
314
+
315
+ latents = callback_outputs.pop("latents", latents)
316
+
317
+ if i == len(timesteps) - 1 or (
318
+ (i + 1) > num_warmup_steps
319
+ and (i + 1) % self.scheduler.order == 0
320
+ ):
321
+ progress_bar.update()
322
+
323
+ if latents_all is None:
324
+ latents_all = latents.clone()
325
+ else:
326
+ assert weights is not None
327
+ # latents_all[:, -overlap:] = (
328
+ # latents[:, :overlap] + latents_all[:, -overlap:]
329
+ # ) / 2.0
330
+ latents_all[:, -overlap:] = latents[
331
+ :, :overlap
332
+ ] * weights + latents_all[:, -overlap:] * (1 - weights)
333
+ latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
334
+
335
+ idx_start += stride
336
+
337
+ if track_time:
338
+ denoise_event.record()
339
+ torch.cuda.synchronize()
340
+ elapsed_time_ms = encode_event.elapsed_time(denoise_event)
341
+ print(f"Elapsed time for denoising video: {elapsed_time_ms} ms")
342
+
343
+ if not output_type == "latent":
344
+ # cast back to fp16 if needed
345
+ if needs_upcasting:
346
+ self.vae.to(dtype=torch.float16)
347
+ frames = self.decode_latents(latents_all, num_frames, decode_chunk_size)
348
+
349
+ if track_time:
350
+ decode_event.record()
351
+ torch.cuda.synchronize()
352
+ elapsed_time_ms = denoise_event.elapsed_time(decode_event)
353
+ print(f"Elapsed time for decoding video: {elapsed_time_ms} ms")
354
+
355
+ frames = self.video_processor.postprocess_video(
356
+ video=frames, output_type=output_type
357
+ )
358
+ else:
359
+ frames = latents_all
360
+
361
+ self.maybe_free_model_hooks()
362
+
363
+ if not return_dict:
364
+ return frames
365
+
366
+ return StableVideoDiffusionPipelineOutput(frames=frames)
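To make the long-video windowing above concrete, here is a small worked example under assumed settings (200 input frames with the default window_size=110 and overlap=25); the numbers are illustrative only:

num_frames, window_size, overlap = 200, 110, 25   # assumed example values
stride = window_size - overlap                    # 85
idx_start, windows = 0, []
while idx_start < num_frames - overlap:
    windows.append((idx_start, min(idx_start + window_size, num_frames)))
    idx_start += stride
print(windows)  # [(0, 110), (85, 195), (170, 200)]
# Each new window's first `overlap` latents are blended with the previous window's
# tail using the linear weights torch.linspace(0, 1, overlap) defined above.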
extern/depthcrafter/infer.py ADDED
@@ -0,0 +1,91 @@
1
+ import gc
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+
6
+ from diffusers.training_utils import set_seed
7
+ from extern.depthcrafter.depth_crafter_ppl import DepthCrafterPipeline
8
+ from extern.depthcrafter.unet import DiffusersUNetSpatioTemporalConditionModelDepthCrafter
9
+
10
+
11
+ class DepthCrafterDemo:
12
+ def __init__(
13
+ self,
14
+ unet_path: str,
15
+ pre_train_path: str,
16
+ cpu_offload: str = "model",
17
+ device: str = "cuda:0"
18
+ ):
19
+ unet = DiffusersUNetSpatioTemporalConditionModelDepthCrafter.from_pretrained(
20
+ unet_path,
21
+ low_cpu_mem_usage=True,
22
+ torch_dtype=torch.float16,
23
+ )
24
+ # load weights of other components from the provided checkpoint
25
+ self.pipe = DepthCrafterPipeline.from_pretrained(
26
+ pre_train_path,
27
+ unet=unet,
28
+ torch_dtype=torch.float16,
29
+ variant="fp16",
30
+ )
31
+
32
+ # to save memory, offload the model to CPU, or run it sequentially to save even more
33
+ if cpu_offload is not None:
34
+ if cpu_offload == "sequential":
35
+ # This will be slower, but saves more memory
36
+ self.pipe.enable_sequential_cpu_offload()
37
+ elif cpu_offload == "model":
38
+ self.pipe.enable_model_cpu_offload()
39
+ else:
40
+ raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
41
+ else:
42
+ self.pipe.to(device)
43
+ # enable attention slicing and xformers memory efficient attention
44
+ try:
45
+ self.pipe.enable_xformers_memory_efficient_attention()
46
+ except Exception as e:
47
+ print(e)
48
+ print("Xformers is not enabled")
49
+ self.pipe.enable_attention_slicing()
50
+
51
+ def infer(
52
+ self,
53
+ frames,
54
+ near,
55
+ far,
56
+ num_denoising_steps: int,
57
+ guidance_scale: float,
58
+ window_size: int = 110,
59
+ overlap: int = 25,
60
+ seed: int = 42,
61
+ track_time: bool = True,
62
+ ):
63
+ set_seed(seed)
64
+
65
+ # inference the depth map using the DepthCrafter pipeline
66
+ with torch.inference_mode():
67
+ res = self.pipe(
68
+ frames,
69
+ height=frames.shape[1],
70
+ width=frames.shape[2],
71
+ output_type="np",
72
+ guidance_scale=guidance_scale,
73
+ num_inference_steps=num_denoising_steps,
74
+ window_size=window_size,
75
+ overlap=overlap,
76
+ track_time=track_time,
77
+ ).frames[0]
78
+ # convert the three-channel output to a single channel depth map
79
+ res = res.sum(-1) / res.shape[-1]
80
+ # normalize the depth map to [0, 1] across the whole video
81
+ depths = (res - res.min()) / (res.max() - res.min())
82
+ # visualize the depth map and save the results
83
+ # vis = vis_sequence_depth(res)
84
+ # save the depth map and visualization with the target FPS
85
+ depths = torch.from_numpy(depths).unsqueeze(1)  # [t, h, w] -> [t, 1, h, w]
86
+ depths *= 3900  # scale to match the Depth Anything output range
87
+ depths[depths < 1e-5] = 1e-5
88
+ depths = 10000. / depths
89
+ depths = depths.clip(near, far)
90
+
91
+ return depths
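A minimal usage sketch of this wrapper; the checkpoint identifiers and the near/far clipping bounds below are placeholders for illustration, not values fixed by this file:

import numpy as np

demo = DepthCrafterDemo(
    unet_path="tencent/DepthCrafter",                                # assumed UNet checkpoint id
    pre_train_path="stabilityai/stable-video-diffusion-img2vid-xt",  # assumed SVD base model id
    cpu_offload="model",
)
frames = np.random.rand(49, 576, 1024, 3).astype(np.float32)  # [t, h, w, c] in [0, 1]
depths = demo.infer(frames, near=0.0001, far=10000.0,         # placeholder clipping bounds
                    num_denoising_steps=5, guidance_scale=1.0)
# depths: torch.Tensor of shape [t, 1, h, w], inverted and clipped to [near, far]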
extern/depthcrafter/unet.py ADDED
@@ -0,0 +1,142 @@
1
+ from typing import Union, Tuple
2
+
3
+ import torch
4
+ from diffusers import UNetSpatioTemporalConditionModel
5
+ from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
6
+
7
+
8
+ class DiffusersUNetSpatioTemporalConditionModelDepthCrafter(
9
+ UNetSpatioTemporalConditionModel
10
+ ):
11
+
12
+ def forward(
13
+ self,
14
+ sample: torch.Tensor,
15
+ timestep: Union[torch.Tensor, float, int],
16
+ encoder_hidden_states: torch.Tensor,
17
+ added_time_ids: torch.Tensor,
18
+ return_dict: bool = True,
19
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
20
+
21
+ # 1. time
22
+ timesteps = timestep
23
+ if not torch.is_tensor(timesteps):
24
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
25
+ # This would be a good case for the `match` statement (Python 3.10+)
26
+ is_mps = sample.device.type == "mps"
27
+ if isinstance(timestep, float):
28
+ dtype = torch.float32 if is_mps else torch.float64
29
+ else:
30
+ dtype = torch.int32 if is_mps else torch.int64
31
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
32
+ elif len(timesteps.shape) == 0:
33
+ timesteps = timesteps[None].to(sample.device)
34
+
35
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
36
+ batch_size, num_frames = sample.shape[:2]
37
+ timesteps = timesteps.expand(batch_size)
38
+
39
+ t_emb = self.time_proj(timesteps)
40
+
41
+ # `Timesteps` does not contain any weights and will always return f32 tensors
42
+ # but time_embedding might actually be running in fp16. so we need to cast here.
43
+ # there might be better ways to encapsulate this.
44
+ t_emb = t_emb.to(dtype=self.conv_in.weight.dtype)
45
+
46
+ emb = self.time_embedding(t_emb) # [batch_size * num_frames, channels]
47
+
48
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
49
+ time_embeds = time_embeds.reshape((batch_size, -1))
50
+ time_embeds = time_embeds.to(emb.dtype)
51
+ aug_emb = self.add_embedding(time_embeds)
52
+ emb = emb + aug_emb
53
+
54
+ # Flatten the batch and frames dimensions
55
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
56
+ sample = sample.flatten(0, 1)
57
+ # Repeat the embeddings num_video_frames times
58
+ # emb: [batch, channels] -> [batch * frames, channels]
59
+ emb = emb.repeat_interleave(num_frames, dim=0)
60
+ # encoder_hidden_states: [batch, frames, channels] -> [batch * frames, 1, channels]
61
+ encoder_hidden_states = encoder_hidden_states.flatten(0, 1).unsqueeze(1)
62
+
63
+ # 2. pre-process
64
+ sample = sample.to(dtype=self.conv_in.weight.dtype)
65
+ assert sample.dtype == self.conv_in.weight.dtype, (
66
+ f"sample.dtype: {sample.dtype}, "
67
+ f"self.conv_in.weight.dtype: {self.conv_in.weight.dtype}"
68
+ )
69
+ sample = self.conv_in(sample)
70
+
71
+ image_only_indicator = torch.zeros(
72
+ batch_size, num_frames, dtype=sample.dtype, device=sample.device
73
+ )
74
+
75
+ down_block_res_samples = (sample,)
76
+ for downsample_block in self.down_blocks:
77
+ if (
78
+ hasattr(downsample_block, "has_cross_attention")
79
+ and downsample_block.has_cross_attention
80
+ ):
81
+ sample, res_samples = downsample_block(
82
+ hidden_states=sample,
83
+ temb=emb,
84
+ encoder_hidden_states=encoder_hidden_states,
85
+ image_only_indicator=image_only_indicator,
86
+ )
87
+
88
+ else:
89
+ sample, res_samples = downsample_block(
90
+ hidden_states=sample,
91
+ temb=emb,
92
+ image_only_indicator=image_only_indicator,
93
+ )
94
+
95
+ down_block_res_samples += res_samples
96
+
97
+ # 4. mid
98
+ sample = self.mid_block(
99
+ hidden_states=sample,
100
+ temb=emb,
101
+ encoder_hidden_states=encoder_hidden_states,
102
+ image_only_indicator=image_only_indicator,
103
+ )
104
+
105
+ # 5. up
106
+ for i, upsample_block in enumerate(self.up_blocks):
107
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
108
+ down_block_res_samples = down_block_res_samples[
109
+ : -len(upsample_block.resnets)
110
+ ]
111
+
112
+ if (
113
+ hasattr(upsample_block, "has_cross_attention")
114
+ and upsample_block.has_cross_attention
115
+ ):
116
+ sample = upsample_block(
117
+ hidden_states=sample,
118
+ res_hidden_states_tuple=res_samples,
119
+ temb=emb,
120
+ encoder_hidden_states=encoder_hidden_states,
121
+ image_only_indicator=image_only_indicator,
122
+ )
123
+ else:
124
+ sample = upsample_block(
125
+ hidden_states=sample,
126
+ res_hidden_states_tuple=res_samples,
127
+ temb=emb,
128
+ image_only_indicator=image_only_indicator,
129
+ )
130
+
131
+ # 6. post-process
132
+ sample = self.conv_norm_out(sample)
133
+ sample = self.conv_act(sample)
134
+ sample = self.conv_out(sample)
135
+
136
+ # 7. Reshape back to original shape
137
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
138
+
139
+ if not return_dict:
140
+ return (sample,)
141
+
142
+ return UNetSpatioTemporalConditionOutput(sample=sample)
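A shape sketch of the input this forward() receives from the pipeline above: the noisy latents are concatenated with the VAE latents of the conditioning video along the channel axis, so the UNet's in_channels must be twice the latent channel count. Sizes are illustrative:

import torch

latents       = torch.randn(1, 110, 4, 72, 128)   # [b, t, c, h, w] noisy latents
video_latents = torch.randn(1, 110, 4, 72, 128)   # VAE latents of the conditioning video
unet_input    = torch.cat([latents, video_latents], dim=2)   # [1, 110, 8, 72, 128]
# forward() then flattens batch and time to [b*t, 8, 72, 128] before conv_in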
extern/video_depth_anything/__pycache__/dinov2.cpython-310.pyc ADDED
Binary file (12.2 kB)
extern/video_depth_anything/__pycache__/dpt.cpython-310.pyc ADDED
Binary file (3.64 kB)
extern/video_depth_anything/__pycache__/dpt_temporal.cpython-310.pyc ADDED
Binary file (2.76 kB)
extern/video_depth_anything/__pycache__/vdademo.cpython-310.pyc ADDED
Binary file (1.52 kB)
extern/video_depth_anything/__pycache__/video_depth.cpython-310.pyc ADDED
Binary file (4.66 kB)
extern/video_depth_anything/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
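A minimal sketch of how this factory is typically driven for dense features; the input resolution and the intermediate-layer indices below are illustrative assumptions:

import torch

encoder = DINOv2("vitl")
x = torch.randn(1, 3, 518, 518)    # H and W must be multiples of patch_size=14
feats = encoder.get_intermediate_layers(x, n=[4, 11, 17, 23], return_class_token=True)
tokens, cls = feats[-1]            # tokens: [1, (518 // 14) ** 2, 1024] = [1, 1369, 1024]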
extern/video_depth_anything/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
extern/video_depth_anything/dinov2_layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (415 Bytes)
extern/video_depth_anything/dinov2_layers/__pycache__/attention.cpython-310.pyc ADDED
Binary file (2.38 kB)
extern/video_depth_anything/dinov2_layers/__pycache__/block.cpython-310.pyc ADDED
Binary file (7.99 kB)
extern/video_depth_anything/dinov2_layers/__pycache__/drop_path.cpython-310.pyc ADDED
Binary file (1.22 kB)
extern/video_depth_anything/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc ADDED
Binary file (1.02 kB)
extern/video_depth_anything/dinov2_layers/__pycache__/mlp.cpython-310.pyc ADDED
Binary file (1.21 kB)
extern/video_depth_anything/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc ADDED
Binary file (2.66 kB)
extern/video_depth_anything/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc ADDED
Binary file (2.01 kB)
extern/video_depth_anything/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
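A small sketch: MemEffAttention transparently falls back to the plain Attention path when xFormers is not installed (in which case attn_bias must stay None). Sizes are illustrative:

import torch

attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)
y = attn(torch.randn(2, 197, 384))   # [batch, tokens, dim] -> same shape out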
extern/video_depth_anything/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
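Sketch of the two call forms: a plain Tensor goes through the standard pre-norm ViT block, while a list of tensors takes the nested path, which additionally requires xFormers and MemEffAttention. Sizes are illustrative:

import torch

block = NestedTensorBlock(dim=384, num_heads=6, mlp_ratio=4.0,
                          qkv_bias=True, attn_class=MemEffAttention)
y = block(torch.randn(2, 197, 384))   # Tensor in, same-shaped Tensor out
# block([x1, x2]) runs the nested-tensor branch and needs xFormers installed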
extern/video_depth_anything/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
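Sketch of the behaviour above: during training each sample is zeroed with probability drop_prob and survivors are rescaled by 1/keep_prob, so the expected output matches the input; at inference time it is the identity:

import torch

dp = DropPath(drop_prob=0.1)
dp.train()
y = dp(torch.ones(4, 197, 384))   # each sample is all zeros or all 1/0.9 ≈ 1.111
dp.eval()
x = torch.randn(4, 197, 384)
assert torch.equal(dp(x), x)      # identity when not training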
extern/video_depth_anything/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
extern/video_depth_anything/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
extern/video_depth_anything/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
75
+
76
+ x = self.proj(x) # B C H W
77
+ H, W = x.size(2), x.size(3)
78
+ x = x.flatten(2).transpose(1, 2) # B HW C
79
+ x = self.norm(x)
80
+ if not self.flatten_embedding:
81
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82
+ return x
83
+
84
+ def flops(self) -> float:
85
+ Ho, Wo = self.patches_resolution
86
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87
+ if self.norm is not None:
88
+ flops += Ho * Wo * self.embed_dim
89
+ return flops
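A quick shape check under the DINOv2 defaults used elsewhere in this directory (518x518 input, 14x14 patches); the sizes are illustrative:

import torch

pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
tokens = pe(torch.randn(1, 3, 518, 518))   # [1, 37 * 37, 1024] = [1, 1369, 1024]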
extern/video_depth_anything/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
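A worked example of the hidden-width rule above, assuming in_features=1536 with the usual 4x expansion: 4 * 1536 = 6144, shrunk by 2/3 to 4096, then rounded up to the next multiple of 8 (already 4096 here, since (int(6144 * 2 / 3) + 7) // 8 * 8 == 4096):

import torch

ffn = SwiGLUFFNFused(in_features=1536, hidden_features=6144)   # internal width becomes 4096
y = ffn(torch.randn(2, 1369, 1536))                            # -> [2, 1369, 1536]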
extern/video_depth_anything/dpt.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from .util.blocks import FeatureFusionBlock, _make_scratch
19
+
20
+
21
+ def _make_fusion_block(features, use_bn, size=None):
22
+ return FeatureFusionBlock(
23
+ features,
24
+ nn.ReLU(False),
25
+ deconv=False,
26
+ bn=use_bn,
27
+ expand=False,
28
+ align_corners=True,
29
+ size=size,
30
+ )
31
+
32
+
33
+ class ConvBlock(nn.Module):
34
+ def __init__(self, in_feature, out_feature):
35
+ super().__init__()
36
+
37
+ self.conv_block = nn.Sequential(
38
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
39
+ nn.BatchNorm2d(out_feature),
40
+ nn.ReLU(True)
41
+ )
42
+
43
+ def forward(self, x):
44
+ return self.conv_block(x)
45
+
46
+
47
+ class DPTHead(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_channels,
51
+ features=256,
52
+ use_bn=False,
53
+ out_channels=[256, 512, 1024, 1024],
54
+ use_clstoken=False
55
+ ):
56
+ super(DPTHead, self).__init__()
57
+
58
+ self.use_clstoken = use_clstoken
59
+
60
+ self.projects = nn.ModuleList([
61
+ nn.Conv2d(
62
+ in_channels=in_channels,
63
+ out_channels=out_channel,
64
+ kernel_size=1,
65
+ stride=1,
66
+ padding=0,
67
+ ) for out_channel in out_channels
68
+ ])
69
+
70
+ self.resize_layers = nn.ModuleList([
71
+ nn.ConvTranspose2d(
72
+ in_channels=out_channels[0],
73
+ out_channels=out_channels[0],
74
+ kernel_size=4,
75
+ stride=4,
76
+ padding=0),
77
+ nn.ConvTranspose2d(
78
+ in_channels=out_channels[1],
79
+ out_channels=out_channels[1],
80
+ kernel_size=2,
81
+ stride=2,
82
+ padding=0),
83
+ nn.Identity(),
84
+ nn.Conv2d(
85
+ in_channels=out_channels[3],
86
+ out_channels=out_channels[3],
87
+ kernel_size=3,
88
+ stride=2,
89
+ padding=1)
90
+ ])
91
+
92
+ if use_clstoken:
93
+ self.readout_projects = nn.ModuleList()
94
+ for _ in range(len(self.projects)):
95
+ self.readout_projects.append(
96
+ nn.Sequential(
97
+ nn.Linear(2 * in_channels, in_channels),
98
+ nn.GELU()))
99
+
100
+ self.scratch = _make_scratch(
101
+ out_channels,
102
+ features,
103
+ groups=1,
104
+ expand=False,
105
+ )
106
+
107
+ self.scratch.stem_transpose = None
108
+
109
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
110
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
111
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
112
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
113
+
114
+ head_features_1 = features
115
+ head_features_2 = 32
116
+
117
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
118
+ self.scratch.output_conv2 = nn.Sequential(
119
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
120
+ nn.ReLU(True),
121
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
122
+ nn.ReLU(True),
123
+ nn.Identity(),
124
+ )
125
+
126
+ def forward(self, out_features, patch_h, patch_w):
127
+ out = []
128
+ for i, x in enumerate(out_features):
129
+ if self.use_clstoken:
130
+ x, cls_token = x[0], x[1]
131
+ readout = cls_token.unsqueeze(1).expand_as(x)
132
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
133
+ else:
134
+ x = x[0]
135
+
136
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
137
+
138
+ x = self.projects[i](x)
139
+ x = self.resize_layers[i](x)
140
+
141
+ out.append(x)
142
+
143
+ layer_1, layer_2, layer_3, layer_4 = out
144
+
145
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
146
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
147
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
148
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
149
+
150
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
151
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
152
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
153
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
154
+
155
+ out = self.scratch.output_conv1(path_1)
156
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
157
+ out = self.scratch.output_conv2(out)
158
+
159
+ return out
160
+
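A shape sketch for DPTHead (editorial, not part of the commit). Each element of out_features is expected to be a (patch_tokens, class_token) pair, as returned by DINOv2's get_intermediate_layers(..., return_class_token=True); the shapes below are illustrative:

import torch
from extern.video_depth_anything.dpt import DPTHead

B, patch_h, patch_w, C = 1, 14, 14, 1024
feats = [(torch.randn(B, patch_h * patch_w, C), torch.randn(B, C)) for _ in range(4)]

head = DPTHead(in_channels=C)            # defaults: features=256, out_channels=[256, 512, 1024, 1024]
depth = head(feats, patch_h, patch_w)    # -> (B, 1, patch_h * 14, patch_w * 14)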
extern/video_depth_anything/dpt_temporal.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.nn as nn
17
+ from .dpt import DPTHead
18
+ from .motion_module.motion_module import TemporalModule
19
+ from easydict import EasyDict
20
+
21
+
22
+ class DPTHeadTemporal(DPTHead):
23
+ def __init__(self,
24
+ in_channels,
25
+ features=256,
26
+ use_bn=False,
27
+ out_channels=[256, 512, 1024, 1024],
28
+ use_clstoken=False,
29
+ num_frames=32,
30
+ pe='ape'
31
+ ):
32
+ super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)
33
+
34
+ assert num_frames > 0
35
+ motion_module_kwargs = EasyDict(num_attention_heads = 8,
36
+ num_transformer_block = 1,
37
+ num_attention_blocks = 2,
38
+ temporal_max_len = num_frames,
39
+ zero_initialize = True,
40
+ pos_embedding_type = pe)
41
+
42
+ self.motion_modules = nn.ModuleList([
43
+ TemporalModule(in_channels=out_channels[2],
44
+ **motion_module_kwargs),
45
+ TemporalModule(in_channels=out_channels[3],
46
+ **motion_module_kwargs),
47
+ TemporalModule(in_channels=features,
48
+ **motion_module_kwargs),
49
+ TemporalModule(in_channels=features,
50
+ **motion_module_kwargs)
51
+ ])
52
+
53
+ def forward(self, out_features, patch_h, patch_w, frame_length):
54
+ out = []
55
+ for i, x in enumerate(out_features):
56
+ if self.use_clstoken:
57
+ x, cls_token = x[0], x[1]
58
+ readout = cls_token.unsqueeze(1).expand_as(x)
59
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
60
+ else:
61
+ x = x[0]
62
+
63
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()
64
+
65
+ B, T = x.shape[0] // frame_length, frame_length
66
+ x = self.projects[i](x)
67
+ x = self.resize_layers[i](x)
68
+
69
+ out.append(x)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = out
72
+
73
+ B, T = layer_1.shape[0] // frame_length, frame_length
74
+
75
+ layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
76
+ layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
77
+
78
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
79
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
80
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
81
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
82
+
83
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
84
+ path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
85
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
86
+ path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
87
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
88
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
89
+
90
+ out = self.scratch.output_conv1(path_1)
91
+ out = F.interpolate(
92
+ out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
93
+ )
94
+ out = self.scratch.output_conv2(out)
95
+
96
+ return out
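A shape sketch for the temporal head (editorial, not part of the commit). Frames are stacked along the batch axis of the token features and frame_length tells the motion modules how to unflatten them; note that with xformers installed the attention path expects CUDA tensors, so move the module and inputs to a GPU in that case:

import torch
from extern.video_depth_anything.dpt_temporal import DPTHeadTemporal

B, T, C, patch_h, patch_w = 1, 8, 1024, 14, 14
feats = [(torch.randn(B * T, patch_h * patch_w, C), torch.randn(B * T, C)) for _ in range(4)]

head = DPTHeadTemporal(in_channels=C, num_frames=32, pe='ape')
depth = head(feats, patch_h, patch_w, frame_length=T)   # -> (B * T, 1, patch_h * 14, patch_w * 14)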
extern/video_depth_anything/motion_module/__pycache__/attention.cpython-310.pyc ADDED
Binary file (12.1 kB).
 
extern/video_depth_anything/motion_module/__pycache__/motion_module.cpython-310.pyc ADDED
Binary file (7.39 kB).
 
extern/video_depth_anything/motion_module/attention.py ADDED
@@ -0,0 +1,429 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ try:
21
+ import xformers
22
+ import xformers.ops
23
+
24
+ XFORMERS_AVAILABLE = True
25
+ except ImportError:
26
+ print("xFormers not available")
27
+ XFORMERS_AVAILABLE = False
28
+
29
+
30
+ class CrossAttention(nn.Module):
31
+ r"""
32
+ A cross attention layer.
33
+
34
+ Parameters:
35
+ query_dim (`int`): The number of channels in the query.
36
+ cross_attention_dim (`int`, *optional*):
37
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
38
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
39
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
40
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
41
+ bias (`bool`, *optional*, defaults to False):
42
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ query_dim: int,
48
+ cross_attention_dim: Optional[int] = None,
49
+ heads: int = 8,
50
+ dim_head: int = 64,
51
+ dropout: float = 0.0,
52
+ bias=False,
53
+ upcast_attention: bool = False,
54
+ upcast_softmax: bool = False,
55
+ added_kv_proj_dim: Optional[int] = None,
56
+ norm_num_groups: Optional[int] = None,
57
+ ):
58
+ super().__init__()
59
+ inner_dim = dim_head * heads
60
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
61
+ self.upcast_attention = upcast_attention
62
+ self.upcast_softmax = upcast_softmax
63
+ self.upcast_efficient_attention = False
64
+
65
+ self.scale = dim_head**-0.5
66
+
67
+ self.heads = heads
68
+ # for slice_size > 0 the attention score computation
69
+ # is split across the batch axis to save memory
70
+ # You can set slice_size with `set_attention_slice`
71
+ self.sliceable_head_dim = heads
72
+ self._slice_size = None
73
+ self._use_memory_efficient_attention_xformers = False
74
+ self.added_kv_proj_dim = added_kv_proj_dim
75
+
76
+ if norm_num_groups is not None:
77
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
78
+ else:
79
+ self.group_norm = None
80
+
81
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
82
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
83
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
84
+
85
+ if self.added_kv_proj_dim is not None:
86
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
87
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
88
+
89
+ self.to_out = nn.ModuleList([])
90
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
91
+ self.to_out.append(nn.Dropout(dropout))
92
+
93
+ def reshape_heads_to_batch_dim(self, tensor):
94
+ batch_size, seq_len, dim = tensor.shape
95
+ head_size = self.heads
96
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
97
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
98
+ return tensor
99
+
100
+ def reshape_heads_to_4d(self, tensor):
101
+ batch_size, seq_len, dim = tensor.shape
102
+ head_size = self.heads
103
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
104
+ return tensor
105
+
106
+ def reshape_batch_dim_to_heads(self, tensor):
107
+ batch_size, seq_len, dim = tensor.shape
108
+ head_size = self.heads
109
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
110
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
111
+ return tensor
112
+
113
+ def reshape_4d_to_heads(self, tensor):
114
+ batch_size, seq_len, head_size, dim = tensor.shape
115
+ head_size = self.heads
116
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()
117
+ return tensor
118
+
119
+ def set_attention_slice(self, slice_size):
120
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
121
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
122
+
123
+ self._slice_size = slice_size
124
+
125
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
126
+ batch_size, sequence_length, _ = hidden_states.shape
127
+
128
+ encoder_hidden_states = encoder_hidden_states
129
+
130
+ if self.group_norm is not None:
131
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = self.to_q(hidden_states)
134
+ dim = query.shape[-1]
135
+ query = self.reshape_heads_to_batch_dim(query)
136
+
137
+ if self.added_kv_proj_dim is not None:
138
+ key = self.to_k(hidden_states)
139
+ value = self.to_v(hidden_states)
140
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
141
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
142
+
143
+ key = self.reshape_heads_to_batch_dim(key)
144
+ value = self.reshape_heads_to_batch_dim(value)
145
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
146
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
147
+
148
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
149
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
150
+ else:
151
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
152
+ key = self.to_k(encoder_hidden_states)
153
+ value = self.to_v(encoder_hidden_states)
154
+
155
+ key = self.reshape_heads_to_batch_dim(key)
156
+ value = self.reshape_heads_to_batch_dim(value)
157
+
158
+ if attention_mask is not None:
159
+ if attention_mask.shape[-1] != query.shape[1]:
160
+ target_length = query.shape[1]
161
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
162
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
163
+
164
+ # attention, what we cannot get enough of
165
+ if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:
166
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
167
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
168
+ hidden_states = hidden_states.to(query.dtype)
169
+ else:
170
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
171
+ hidden_states = self._attention(query, key, value, attention_mask)
172
+ else:
173
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
174
+
175
+ # linear proj
176
+ hidden_states = self.to_out[0](hidden_states)
177
+
178
+ # dropout
179
+ hidden_states = self.to_out[1](hidden_states)
180
+ return hidden_states
181
+
182
+ def _attention(self, query, key, value, attention_mask=None):
183
+ if self.upcast_attention:
184
+ query = query.float()
185
+ key = key.float()
186
+
187
+ attention_scores = torch.baddbmm(
188
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
189
+ query,
190
+ key.transpose(-1, -2),
191
+ beta=0,
192
+ alpha=self.scale,
193
+ )
194
+
195
+ if attention_mask is not None:
196
+ attention_scores = attention_scores + attention_mask
197
+
198
+ if self.upcast_softmax:
199
+ attention_scores = attention_scores.float()
200
+
201
+ attention_probs = attention_scores.softmax(dim=-1)
202
+
203
+ # cast back to the original dtype
204
+ attention_probs = attention_probs.to(value.dtype)
205
+
206
+ # compute attention output
207
+ hidden_states = torch.bmm(attention_probs, value)
208
+
209
+ # reshape hidden_states
210
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
211
+ return hidden_states
212
+
213
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
214
+ batch_size_attention = query.shape[0]
215
+ hidden_states = torch.zeros(
216
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
217
+ )
218
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
219
+ for i in range(hidden_states.shape[0] // slice_size):
220
+ start_idx = i * slice_size
221
+ end_idx = (i + 1) * slice_size
222
+
223
+ query_slice = query[start_idx:end_idx]
224
+ key_slice = key[start_idx:end_idx]
225
+
226
+ if self.upcast_attention:
227
+ query_slice = query_slice.float()
228
+ key_slice = key_slice.float()
229
+
230
+ attn_slice = torch.baddbmm(
231
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
232
+ query_slice,
233
+ key_slice.transpose(-1, -2),
234
+ beta=0,
235
+ alpha=self.scale,
236
+ )
237
+
238
+ if attention_mask is not None:
239
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
240
+
241
+ if self.upcast_softmax:
242
+ attn_slice = attn_slice.float()
243
+
244
+ attn_slice = attn_slice.softmax(dim=-1)
245
+
246
+ # cast back to the original dtype
247
+ attn_slice = attn_slice.to(value.dtype)
248
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
249
+
250
+ hidden_states[start_idx:end_idx] = attn_slice
251
+
252
+ # reshape hidden_states
253
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
254
+ return hidden_states
255
+
256
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
257
+ if self.upcast_efficient_attention:
258
+ org_dtype = query.dtype
259
+ query = query.float()
260
+ key = key.float()
261
+ value = value.float()
262
+ if attention_mask is not None:
263
+ attention_mask = attention_mask.float()
264
+ hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)
265
+
266
+ if self.upcast_efficient_attention:
267
+ hidden_states = hidden_states.to(org_dtype)
268
+
269
+ hidden_states = self.reshape_4d_to_heads(hidden_states)
270
+ return hidden_states
271
+
272
+ # print("Error: no xformers")
273
+ # raise NotImplementedError
274
+
275
+ def _memory_efficient_attention_split(self, query, key, value, attention_mask):
276
+ batch_size = query.shape[0]
277
+ max_batch_size = 65535
278
+ num_batches = (batch_size + max_batch_size - 1) // max_batch_size
279
+ results = []
280
+ for i in range(num_batches):
281
+ start_idx = i * max_batch_size
282
+ end_idx = min((i + 1) * max_batch_size, batch_size)
283
+ query_batch = query[start_idx:end_idx]
284
+ key_batch = key[start_idx:end_idx]
285
+ value_batch = value[start_idx:end_idx]
286
+ if attention_mask is not None:
287
+ attention_mask_batch = attention_mask[start_idx:end_idx]
288
+ else:
289
+ attention_mask_batch = None
290
+ result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
291
+ results.append(result)
292
+ full_result = torch.cat(results, dim=0)
293
+ return full_result
294
+
295
+
296
+ class FeedForward(nn.Module):
297
+ r"""
298
+ A feed-forward layer.
299
+
300
+ Parameters:
301
+ dim (`int`): The number of channels in the input.
302
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
303
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
304
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
305
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ dim: int,
311
+ dim_out: Optional[int] = None,
312
+ mult: int = 4,
313
+ dropout: float = 0.0,
314
+ activation_fn: str = "geglu",
315
+ ):
316
+ super().__init__()
317
+ inner_dim = int(dim * mult)
318
+ dim_out = dim_out if dim_out is not None else dim
319
+
320
+ if activation_fn == "gelu":
321
+ act_fn = GELU(dim, inner_dim)
322
+ elif activation_fn == "geglu":
323
+ act_fn = GEGLU(dim, inner_dim)
324
+ elif activation_fn == "geglu-approximate":
325
+ act_fn = ApproximateGELU(dim, inner_dim)
+ else:
+ raise ValueError(f"Unsupported activation_fn: {activation_fn}")
326
+
327
+ self.net = nn.ModuleList([])
328
+ # project in
329
+ self.net.append(act_fn)
330
+ # project dropout
331
+ self.net.append(nn.Dropout(dropout))
332
+ # project out
333
+ self.net.append(nn.Linear(inner_dim, dim_out))
334
+
335
+ def forward(self, hidden_states):
336
+ for module in self.net:
337
+ hidden_states = module(hidden_states)
338
+ return hidden_states
339
+
340
+
341
+ class GELU(nn.Module):
342
+ r"""
343
+ GELU activation function
344
+ """
345
+
346
+ def __init__(self, dim_in: int, dim_out: int):
347
+ super().__init__()
348
+ self.proj = nn.Linear(dim_in, dim_out)
349
+
350
+ def gelu(self, gate):
351
+ if gate.device.type != "mps":
352
+ return F.gelu(gate)
353
+ # mps: gelu is not implemented for float16
354
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
355
+
356
+ def forward(self, hidden_states):
357
+ hidden_states = self.proj(hidden_states)
358
+ hidden_states = self.gelu(hidden_states)
359
+ return hidden_states
360
+
361
+
362
+ # feedforward
363
+ class GEGLU(nn.Module):
364
+ r"""
365
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
366
+
367
+ Parameters:
368
+ dim_in (`int`): The number of channels in the input.
369
+ dim_out (`int`): The number of channels in the output.
370
+ """
371
+
372
+ def __init__(self, dim_in: int, dim_out: int):
373
+ super().__init__()
374
+ self.proj = nn.Linear(dim_in, dim_out * 2)
375
+
376
+ def gelu(self, gate):
377
+ if gate.device.type != "mps":
378
+ return F.gelu(gate)
379
+ # mps: gelu is not implemented for float16
380
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
381
+
382
+ def forward(self, hidden_states):
383
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
384
+ return hidden_states * self.gelu(gate)
385
+
386
+
387
+ class ApproximateGELU(nn.Module):
388
+ """
389
+ The approximate form of Gaussian Error Linear Unit (GELU)
390
+
391
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
392
+ """
393
+
394
+ def __init__(self, dim_in: int, dim_out: int):
395
+ super().__init__()
396
+ self.proj = nn.Linear(dim_in, dim_out)
397
+
398
+ def forward(self, x):
399
+ x = self.proj(x)
400
+ return x * torch.sigmoid(1.702 * x)
401
+
402
+
403
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
404
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
405
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
406
+ freqs = torch.outer(t, freqs)
407
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
408
+ return freqs_cis
409
+
410
+
411
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
412
+ ndim = x.ndim
413
+ assert 0 <= 1 < ndim
414
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
415
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
416
+ return freqs_cis.view(*shape)
417
+
418
+
419
+ def apply_rotary_emb(
420
+ xq: torch.Tensor,
421
+ xk: torch.Tensor,
422
+ freqs_cis: torch.Tensor,
423
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
424
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())
425
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())
426
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
427
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
428
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
429
+ return xq_out.type_as(xq), xk_out.type_as(xk)
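A quick shape check for the rotary-embedding helpers above (editorial sketch). precompute_freqs_cis returns a complex (seq_len, dim // 2) table, and apply_rotary_emb rotates query/key pairs without changing their shapes:

import torch
from extern.video_depth_anything.motion_module.attention import precompute_freqs_cis, apply_rotary_emb

dim, seq_len = 64, 16
freqs_cis = precompute_freqs_cis(dim, seq_len)    # (16, 32), complex64
q = torch.randn(4, seq_len, dim)
k = torch.randn(4, seq_len, dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)  # same shapes as q and k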
extern/video_depth_anything/motion_module/motion_module.py ADDED
@@ -0,0 +1,297 @@
1
+ # This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
2
+ # SPDX-License-Identifier: Apache-2.0 license
3
+ #
4
+ # This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
5
+ # Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
11
+
12
+ from einops import rearrange, repeat
13
+ import math
14
+
15
+ try:
16
+ import xformers
17
+ import xformers.ops
18
+
19
+ XFORMERS_AVAILABLE = True
20
+ except ImportError:
21
+ print("xFormers not available")
22
+ XFORMERS_AVAILABLE = False
23
+
24
+
25
+ def zero_module(module):
26
+ # Zero out the parameters of a module and return it.
27
+ for p in module.parameters():
28
+ p.detach().zero_()
29
+ return module
30
+
31
+
32
+ class TemporalModule(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ num_attention_heads = 8,
37
+ num_transformer_block = 2,
38
+ num_attention_blocks = 2,
39
+ norm_num_groups = 32,
40
+ temporal_max_len = 32,
41
+ zero_initialize = True,
42
+ pos_embedding_type = "ape",
43
+ ):
44
+ super().__init__()
45
+
46
+ self.temporal_transformer = TemporalTransformer3DModel(
47
+ in_channels=in_channels,
48
+ num_attention_heads=num_attention_heads,
49
+ attention_head_dim=in_channels // num_attention_heads,
50
+ num_layers=num_transformer_block,
51
+ num_attention_blocks=num_attention_blocks,
52
+ norm_num_groups=norm_num_groups,
53
+ temporal_max_len=temporal_max_len,
54
+ pos_embedding_type=pos_embedding_type,
55
+ )
56
+
57
+ if zero_initialize:
58
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
59
+
60
+ def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
61
+ hidden_states = input_tensor
62
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
63
+
64
+ output = hidden_states
65
+ return output
66
+
67
+
68
+ class TemporalTransformer3DModel(nn.Module):
69
+ def __init__(
70
+ self,
71
+ in_channels,
72
+ num_attention_heads,
73
+ attention_head_dim,
74
+ num_layers,
75
+ num_attention_blocks = 2,
76
+ norm_num_groups = 32,
77
+ temporal_max_len = 32,
78
+ pos_embedding_type = "ape",
79
+ ):
80
+ super().__init__()
81
+
82
+ inner_dim = num_attention_heads * attention_head_dim
83
+
84
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
85
+ self.proj_in = nn.Linear(in_channels, inner_dim)
86
+
87
+ self.transformer_blocks = nn.ModuleList(
88
+ [
89
+ TemporalTransformerBlock(
90
+ dim=inner_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ attention_head_dim=attention_head_dim,
93
+ num_attention_blocks=num_attention_blocks,
94
+ temporal_max_len=temporal_max_len,
95
+ pos_embedding_type=pos_embedding_type,
96
+ )
97
+ for d in range(num_layers)
98
+ ]
99
+ )
100
+ self.proj_out = nn.Linear(inner_dim, in_channels)
101
+
102
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
103
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
104
+ video_length = hidden_states.shape[2]
105
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
106
+
107
+ batch, channel, height, width = hidden_states.shape
108
+ residual = hidden_states
109
+
110
+ hidden_states = self.norm(hidden_states)
111
+ inner_dim = hidden_states.shape[1]
112
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()
113
+ hidden_states = self.proj_in(hidden_states)
114
+
115
+ # Transformer Blocks
116
+ for block in self.transformer_blocks:
117
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
118
+
119
+ # output
120
+ hidden_states = self.proj_out(hidden_states)
121
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
122
+
123
+ output = hidden_states + residual
124
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
125
+
126
+ return output
127
+
128
+
129
+ class TemporalTransformerBlock(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim,
133
+ num_attention_heads,
134
+ attention_head_dim,
135
+ num_attention_blocks = 2,
136
+ temporal_max_len = 32,
137
+ pos_embedding_type = "ape",
138
+ ):
139
+ super().__init__()
140
+
141
+ self.attention_blocks = nn.ModuleList(
142
+ [
143
+ TemporalAttention(
144
+ query_dim=dim,
145
+ heads=num_attention_heads,
146
+ dim_head=attention_head_dim,
147
+ temporal_max_len=temporal_max_len,
148
+ pos_embedding_type=pos_embedding_type,
149
+ )
150
+ for i in range(num_attention_blocks)
151
+ ]
152
+ )
153
+ self.norms = nn.ModuleList(
154
+ [
155
+ nn.LayerNorm(dim)
156
+ for i in range(num_attention_blocks)
157
+ ]
158
+ )
159
+
160
+ self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
161
+ self.ff_norm = nn.LayerNorm(dim)
162
+
163
+
164
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
165
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
166
+ norm_hidden_states = norm(hidden_states)
167
+ hidden_states = attention_block(
168
+ norm_hidden_states,
169
+ encoder_hidden_states=encoder_hidden_states,
170
+ video_length=video_length,
171
+ attention_mask=attention_mask,
172
+ ) + hidden_states
173
+
174
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
175
+
176
+ output = hidden_states
177
+ return output
178
+
179
+
180
+ class PositionalEncoding(nn.Module):
181
+ def __init__(
182
+ self,
183
+ d_model,
184
+ dropout = 0.,
185
+ max_len = 32
186
+ ):
187
+ super().__init__()
188
+ self.dropout = nn.Dropout(p=dropout)
189
+ position = torch.arange(max_len).unsqueeze(1)
190
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
191
+ pe = torch.zeros(1, max_len, d_model)
192
+ pe[0, :, 0::2] = torch.sin(position * div_term)
193
+ pe[0, :, 1::2] = torch.cos(position * div_term)
194
+ self.register_buffer('pe', pe)
195
+
196
+ def forward(self, x):
197
+ x = x + self.pe[:, :x.size(1)].to(x.dtype)
198
+ return self.dropout(x)
199
+
200
+ class TemporalAttention(CrossAttention):
201
+ def __init__(
202
+ self,
203
+ temporal_max_len = 32,
204
+ pos_embedding_type = "ape",
205
+ *args, **kwargs
206
+ ):
207
+ super().__init__(*args, **kwargs)
208
+
209
+ self.pos_embedding_type = pos_embedding_type
210
+ self._use_memory_efficient_attention_xformers = True
211
+
212
+ self.pos_encoder = None
213
+ self.freqs_cis = None
214
+ if self.pos_embedding_type == "ape":
215
+ self.pos_encoder = PositionalEncoding(
216
+ kwargs["query_dim"],
217
+ dropout=0.,
218
+ max_len=temporal_max_len
219
+ )
220
+
221
+ elif self.pos_embedding_type == "rope":
222
+ self.freqs_cis = precompute_freqs_cis(
223
+ kwargs["query_dim"],
224
+ temporal_max_len
225
+ )
226
+
227
+ else:
228
+ raise NotImplementedError
229
+
230
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
231
+ d = hidden_states.shape[1]
232
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
233
+
234
+ if self.pos_encoder is not None:
235
+ hidden_states = self.pos_encoder(hidden_states)
236
+
237
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
238
+
239
+ if self.group_norm is not None:
240
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
241
+
242
+ query = self.to_q(hidden_states)
243
+ dim = query.shape[-1]
244
+
245
+ if self.added_kv_proj_dim is not None:
246
+ raise NotImplementedError
247
+
248
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
249
+ key = self.to_k(encoder_hidden_states)
250
+ value = self.to_v(encoder_hidden_states)
251
+
252
+ if self.freqs_cis is not None:
253
+ seq_len = query.shape[1]
254
+ freqs_cis = self.freqs_cis[:seq_len].to(query.device)
255
+ query, key = apply_rotary_emb(query, key, freqs_cis)
256
+
257
+ if attention_mask is not None:
258
+ if attention_mask.shape[-1] != query.shape[1]:
259
+ target_length = query.shape[1]
260
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
261
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
262
+
263
+
264
+ use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers
265
+ if use_memory_efficient and (dim // self.heads) % 8 != 0:
266
+ # print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
267
+ use_memory_efficient = False
268
+
269
+ # attention, what we cannot get enough of
270
+ if use_memory_efficient:
271
+ query = self.reshape_heads_to_4d(query)
272
+ key = self.reshape_heads_to_4d(key)
273
+ value = self.reshape_heads_to_4d(value)
274
+
275
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
276
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
277
+ hidden_states = hidden_states.to(query.dtype)
278
+ else:
279
+ query = self.reshape_heads_to_batch_dim(query)
280
+ key = self.reshape_heads_to_batch_dim(key)
281
+ value = self.reshape_heads_to_batch_dim(value)
282
+
283
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
284
+ hidden_states = self._attention(query, key, value, attention_mask)
285
+ else:
286
+ raise NotImplementedError
287
+ # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
288
+
289
+ # linear proj
290
+ hidden_states = self.to_out[0](hidden_states)
291
+
292
+ # dropout
293
+ hidden_states = self.to_out[1](hidden_states)
294
+
295
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
296
+
297
+ return hidden_states
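A minimal sketch of TemporalModule (editorial, not part of the commit). It attends over the frame axis of a (B, C, T, H, W) feature map and, because proj_out is zero-initialized, it starts out as an identity mapping; with xformers installed the attention path expects CUDA tensors:

import torch
from extern.video_depth_anything.motion_module.motion_module import TemporalModule

B, C, T, H, W = 1, 256, 8, 14, 14
feat = torch.randn(B, C, T, H, W)
temporal = TemporalModule(in_channels=C, temporal_max_len=32)
out = temporal(feat, None, None)   # -> (B, C, T, H, W); equals feat at initialization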
extern/video_depth_anything/util/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (3.24 kB).
 
extern/video_depth_anything/util/__pycache__/transform.cpython-310.pyc ADDED
Binary file (4.72 kB).
 
extern/video_depth_anything/util/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.66 kB).
 
extern/video_depth_anything/util/blocks.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(
21
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
22
+ )
23
+ scratch.layer2_rn = nn.Conv2d(
24
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
25
+ )
26
+ scratch.layer3_rn = nn.Conv2d(
27
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
28
+ )
29
+ if len(in_shape) >= 4:
30
+ scratch.layer4_rn = nn.Conv2d(
31
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
32
+ )
33
+
34
+ return scratch
35
+
36
+
37
+ class ResidualConvUnit(nn.Module):
38
+ """Residual convolution module."""
39
+
40
+ def __init__(self, features, activation, bn):
41
+ """Init.
42
+
43
+ Args:
44
+ features (int): number of features
45
+ """
46
+ super().__init__()
47
+
48
+ self.bn = bn
49
+
50
+ self.groups = 1
51
+
52
+ self.conv1 = nn.Conv2d(
53
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
54
+ )
55
+
56
+ self.conv2 = nn.Conv2d(
57
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
58
+ )
59
+
60
+ if self.bn is True:
61
+ self.bn1 = nn.BatchNorm2d(features)
62
+ self.bn2 = nn.BatchNorm2d(features)
63
+
64
+ self.activation = activation
65
+
66
+ self.skip_add = nn.quantized.FloatFunctional()
67
+
68
+ def forward(self, x):
69
+ """Forward pass.
70
+
71
+ Args:
72
+ x (tensor): input
73
+
74
+ Returns:
75
+ tensor: output
76
+ """
77
+
78
+ out = self.activation(x)
79
+ out = self.conv1(out)
80
+ if self.bn is True:
81
+ out = self.bn1(out)
82
+
83
+ out = self.activation(out)
84
+ out = self.conv2(out)
85
+ if self.bn is True:
86
+ out = self.bn2(out)
87
+
88
+ if self.groups > 1:
89
+ out = self.conv_merge(out)
90
+
91
+ return self.skip_add.add(out, x)
92
+
93
+
94
+ class FeatureFusionBlock(nn.Module):
95
+ """Feature fusion block."""
96
+
97
+ def __init__(
98
+ self,
99
+ features,
100
+ activation,
101
+ deconv=False,
102
+ bn=False,
103
+ expand=False,
104
+ align_corners=True,
105
+ size=None,
106
+ ):
107
+ """Init.
108
+
109
+ Args:
110
+ features (int): number of features
111
+ """
112
+ super().__init__()
113
+
114
+ self.deconv = deconv
115
+ self.align_corners = align_corners
116
+
117
+ self.groups = 1
118
+
119
+ self.expand = expand
120
+ out_features = features
121
+ if self.expand is True:
122
+ out_features = features // 2
123
+
124
+ self.out_conv = nn.Conv2d(
125
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
126
+ )
127
+
128
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
129
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
130
+
131
+ self.skip_add = nn.quantized.FloatFunctional()
132
+
133
+ self.size = size
134
+
135
+ def forward(self, *xs, size=None):
136
+ """Forward pass.
137
+
138
+ Returns:
139
+ tensor: output
140
+ """
141
+ output = xs[0]
142
+
143
+ if len(xs) == 2:
144
+ res = self.resConfUnit1(xs[1])
145
+ output = self.skip_add.add(output, res)
146
+
147
+ output = self.resConfUnit2(output)
148
+
149
+ if (size is None) and (self.size is None):
150
+ modifier = {"scale_factor": 2}
151
+ elif size is None:
152
+ modifier = {"size": self.size}
153
+ else:
154
+ modifier = {"size": size}
155
+
156
+ output = nn.functional.interpolate(
157
+ output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners
158
+ )
159
+
160
+ output = self.out_conv(output)
161
+
162
+ return output
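A small sketch of how these pieces compose (editorial). _make_scratch projects each backbone level onto a common channel width with 3x3 convolutions, and FeatureFusionBlock merges a coarse path with a lateral skip and upsamples by 2:

import torch
import torch.nn as nn
from extern.video_depth_anything.util.blocks import FeatureFusionBlock, _make_scratch

scratch = _make_scratch([256, 512, 1024, 1024], 256)   # scratch.layer{1..4}_rn: 3x3 convs to 256 channels
fuse = FeatureFusionBlock(256, nn.ReLU(False))

coarse = torch.randn(1, 256, 16, 16)   # deeper decoder path
skip = torch.randn(1, 256, 16, 16)     # lateral features at the same resolution
out = fuse(coarse, skip)               # -> (1, 256, 32, 32): fused, then upsampled x2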
extern/video_depth_anything/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height are constrained to be a multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than the given size.)
39
+ "minimal": Scale as little as possible. (Output size might be smaller than the given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normalize image by the given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
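An editorial usage sketch mirroring the preprocessing set up later in video_depth.py: keep the aspect ratio, make the short side at least 518 px, and round both sides to a multiple of 14 (the ViT patch size):

import numpy as np
from extern.video_depth_anything.util.transform import Resize

resize = Resize(518, 518, resize_target=False, keep_aspect_ratio=True,
                ensure_multiple_of=14, resize_method='lower_bound')
sample = {"image": np.random.rand(720, 1280, 3).astype(np.float32)}
print(resize(sample)["image"].shape)   # (518, 924, 3): height hits the lower bound, width scales to match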
extern/video_depth_anything/util/util.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ def compute_scale_and_shift(prediction, target, mask, scale_only=False):
17
+ if scale_only:
18
+ return compute_scale(prediction, target, mask), 0
19
+ else:
20
+ return compute_scale_and_shift_full(prediction, target, mask)
21
+
22
+
23
+ def compute_scale(prediction, target, mask):
24
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
25
+ prediction = prediction.astype(np.float32)
26
+ target = target.astype(np.float32)
27
+ mask = mask.astype(np.float32)
28
+
29
+ a_00 = np.sum(mask * prediction * prediction)
30
+ a_01 = np.sum(mask * prediction)
31
+ a_11 = np.sum(mask)
32
+
33
+ # right hand side: b = [b_0, b_1]
34
+ b_0 = np.sum(mask * prediction * target)
35
+
36
+ x_0 = b_0 / (a_00 + 1e-6)
37
+
38
+ return x_0
39
+
40
+ def compute_scale_and_shift_full(prediction, target, mask):
41
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
42
+ prediction = prediction.astype(np.float32)
43
+ target = target.astype(np.float32)
44
+ mask = mask.astype(np.float32)
45
+
46
+ a_00 = np.sum(mask * prediction * prediction)
47
+ a_01 = np.sum(mask * prediction)
48
+ a_11 = np.sum(mask)
49
+
50
+ b_0 = np.sum(mask * prediction * target)
51
+ b_1 = np.sum(mask * target)
52
+
53
+ x_0 = 1
54
+ x_1 = 0
55
+
56
+ det = a_00 * a_11 - a_01 * a_01
57
+
58
+ if det != 0:
59
+ x_0 = (a_11 * b_0 - a_01 * b_1) / det
60
+ x_1 = (-a_01 * b_0 + a_00 * b_1) / det
61
+
62
+ return x_0, x_1
63
+
64
+
65
+ def get_interpolate_frames(frame_list_pre, frame_list_post):
66
+ assert len(frame_list_pre) == len(frame_list_post)
67
+ min_w = 0.0
68
+ max_w = 1.0
69
+ step = (max_w - min_w) / (len(frame_list_pre)-1)
70
+ post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
71
+ interpolated_frames = []
72
+ for i in range(len(frame_list_pre)):
73
+ interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
74
+ return interpolated_frames
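A tiny editorial check of compute_scale_and_shift, which solves the 2x2 normal equations for the scale s and shift t minimizing sum(mask * (s * prediction + t - target)^2); it recovers an exact affine relation:

import numpy as np
from extern.video_depth_anything.util.util import compute_scale_and_shift

pred = np.array([1.0, 2.0, 3.0, 4.0])
target = 2.5 * pred + 0.7
mask = np.ones_like(pred)
scale, shift = compute_scale_and_shift(pred, target, mask)
print(scale, shift)   # 2.5 0.7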
extern/video_depth_anything/vdademo.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import numpy as np
16
+ import os
17
+ import torch
18
+ from extern.video_depth_anything.video_depth import VideoDepthAnything
19
+
20
+ class VDADemo:
21
+ def __init__(
22
+ self,
23
+ pre_train_path: str,
24
+ encoder: str = "vitl",
25
+ device: str = "cuda:0",
26
+ ):
27
+
28
+ model_configs = {
29
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
30
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
31
+ }
32
+
33
+ self.video_depth_anything = VideoDepthAnything(**model_configs[encoder])
34
+ self.video_depth_anything.load_state_dict(torch.load(pre_train_path, map_location='cpu'), strict=True)
35
+ self.video_depth_anything = self.video_depth_anything.to(device).eval()
36
+ self.device = device
37
+
38
+ def infer(
39
+ self,
40
+ frames,
41
+ near,
42
+ far,
43
+ input_size = 518,
44
+ target_fps = -1,
45
+ ):
46
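+ # heuristic: inputs that look [0, 1]-normalized are rescaled to [0, 255] before depth inference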
+ if frames.max() < 2.:
47
+ frames = frames*255.
48
+
49
+ with torch.inference_mode():
50
+ depths, fps = self.video_depth_anything.infer_video_depth(frames, target_fps, input_size, self.device)
51
+
52
+ depths = torch.from_numpy(depths).unsqueeze(1) # (T, H, W), e.g. 49x576x1024 -> (T, 1, H, W)
53
+ depths[depths < 1e-5] = 1e-5
54
+ depths = 10000. / depths
55
+ depths = depths.clip(near, far)
56
+
57
+
58
+ return depths
59
+
60
+
61
+
62
+
63
+
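An editorial usage sketch for VDADemo; the checkpoint path and the near/far clipping values are placeholders:

import numpy as np
from extern.video_depth_anything.vdademo import VDADemo

vda = VDADemo(pre_train_path="ckpts/video_depth_anything_vitl.pth", encoder="vitl", device="cuda:0")
frames = np.random.rand(49, 576, 1024, 3).astype(np.float32)   # (T, H, W, 3), RGB in [0, 1]
depths = vda.infer(frames, near=0.0001, far=10000.0)           # torch.Tensor, (T, 1, H, W), clipped to [near, far]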
extern/video_depth_anything/video_depth.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.nn as nn
17
+ from torchvision.transforms import Compose
18
+ import cv2
19
+ from tqdm import tqdm
20
+ import numpy as np
21
+ import gc
22
+
23
+ from extern.video_depth_anything.dinov2 import DINOv2
24
+ from extern.video_depth_anything.dpt_temporal import DPTHeadTemporal
25
+ from extern.video_depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
26
+
27
+ from extern.video_depth_anything.util.util import compute_scale_and_shift, get_interpolate_frames
28
+
29
+ # infer settings, do not change
30
+ INFER_LEN = 32
31
+ OVERLAP = 10
32
+ KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
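+ # KEYFRAMES: positions in the previous 32-frame window whose inputs seed the first OVERLAP frames of the next window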
33
+ INTERP_LEN = 8
34
+
35
+ class VideoDepthAnything(nn.Module):
36
+ def __init__(
37
+ self,
38
+ encoder='vitl',
39
+ features=256,
40
+ out_channels=[256, 512, 1024, 1024],
41
+ use_bn=False,
42
+ use_clstoken=False,
43
+ num_frames=32,
44
+ pe='ape'
45
+ ):
46
+ super(VideoDepthAnything, self).__init__()
47
+
48
+ self.intermediate_layer_idx = {
49
+ 'vits': [2, 5, 8, 11],
50
+ 'vitl': [4, 11, 17, 23]
51
+ }
52
+
53
+ self.encoder = encoder
54
+ self.pretrained = DINOv2(model_name=encoder)
55
+
56
+ self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
57
+
58
+ def forward(self, x):
59
+ B, T, C, H, W = x.shape
60
+ patch_h, patch_w = H // 14, W // 14
61
+ features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
62
+ depth = self.head(features, patch_h, patch_w, T)
63
+ depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
64
+ depth = F.relu(depth)
65
+ return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]
66
+
67
+ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda'):
68
+ frame_height, frame_width = frames[0].shape[:2]
69
+ ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
70
+ if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation
71
+ input_size = int(input_size * 1.777 / ratio)
72
+ input_size = round(input_size / 14) * 14
73
+
74
+ transform = Compose([
75
+ Resize(
76
+ width=input_size,
77
+ height=input_size,
78
+ resize_target=False,
79
+ keep_aspect_ratio=True,
80
+ ensure_multiple_of=14,
81
+ resize_method='lower_bound',
82
+ image_interpolation_method=cv2.INTER_CUBIC,
83
+ ),
84
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
85
+ PrepareForNet(),
86
+ ])
87
+
88
+ frame_list = [frames[i] for i in range(frames.shape[0])]
89
+ frame_step = INFER_LEN - OVERLAP
90
+ org_video_len = len(frame_list)
91
+ append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
92
+ frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
93
+
94
+ depth_list = []
95
+ pre_input = None
96
+ for frame_id in tqdm(range(0, org_video_len, frame_step)):
97
+ cur_list = []
98
+ for i in range(INFER_LEN):
99
+ cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
100
+ cur_input = torch.cat(cur_list, dim=1).to(device)
101
+ if pre_input is not None:
102
+ cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]
103
+
104
+ with torch.no_grad():
105
+ depth = self.forward(cur_input) # depth shape: [1, T, H, W]
106
+
107
+ depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
108
+ depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]
109
+
110
+ pre_input = cur_input
111
+
112
+ del frame_list
113
+ gc.collect()
114
+
115
+ depth_list_aligned = []
116
+ ref_align = []
117
+ align_len = OVERLAP - INTERP_LEN
118
+ kf_align_list = KEYFRAMES[:align_len]
119
+
120
+ for frame_id in range(0, len(depth_list), INFER_LEN):
121
+ if len(depth_list_aligned) == 0:
122
+ depth_list_aligned += depth_list[:INFER_LEN]
123
+ for kf_id in kf_align_list:
124
+ ref_align.append(depth_list[frame_id+kf_id])
125
+ else:
126
+ curr_align = []
127
+ for i in range(len(kf_align_list)):
128
+ curr_align.append(depth_list[frame_id+i])
129
+ scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
130
+ np.concatenate(ref_align),
131
+ np.concatenate(np.ones_like(ref_align)==1))
132
+
133
+ pre_depth_list = depth_list_aligned[-INTERP_LEN:]
134
+ post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
135
+ for i in range(len(post_depth_list)):
136
+ post_depth_list[i] = post_depth_list[i] * scale + shift
137
+ post_depth_list[i][post_depth_list[i]<0] = 0
138
+ depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
139
+
140
+ for i in range(OVERLAP, INFER_LEN):
141
+ new_depth = depth_list[frame_id+i] * scale + shift
142
+ new_depth[new_depth<0] = 0
143
+ depth_list_aligned.append(new_depth)
144
+
145
+ ref_align = ref_align[:1]
146
+ for kf_id in kf_align_list[1:]:
147
+ new_depth = depth_list[frame_id+kf_id] * scale + shift
148
+ new_depth[new_depth<0] = 0
149
+ ref_align.append(new_depth)
150
+
151
+ depth_list = depth_list_aligned
152
+
153
+ return np.stack(depth_list[:org_video_len], axis=0), target_fps
154
+
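An editorial sketch of the sliding-window schedule implemented by infer_video_depth above: windows of INFER_LEN frames advance by INFER_LEN - OVERLAP, the tail is padded with copies of the last frame, and each new window is later aligned to the previous one via compute_scale_and_shift on the overlapping keyframes:

INFER_LEN, OVERLAP = 32, 10
frame_step = INFER_LEN - OVERLAP             # 22 new frames per window
org_video_len = 100                          # illustrative clip length
append = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
starts = list(range(0, org_video_len, frame_step))
print(append)   # 20: copies of the last frame so the final window is full (100 + 20 = 88 + 32)
print(starts)   # [0, 22, 44, 66, 88]: start index of each 32-frame window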