zixinz committed
Commit 5a0778e · 1 Parent(s): f1483c5

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. app.py +56 -0
  2. code_depth/LICENSE +201 -0
  3. code_depth/README.md +120 -0
  4. code_depth/app.py +152 -0
  5. code_depth/assets/example_videos/Tokyo-Walk_rgb.mp4 +3 -0
  6. code_depth/assets/example_videos/davis_rollercoaster.mp4 +3 -0
  7. code_depth/assets/teaser_video_v2.png +3 -0
  8. code_depth/benchmark/README.md +34 -0
  9. code_depth/benchmark/__init__.py +0 -0
  10. code_depth/benchmark/dataset_extract/dataset_extract_bonn.py +86 -0
  11. code_depth/benchmark/dataset_extract/dataset_extract_kitti.py +84 -0
  12. code_depth/benchmark/dataset_extract/dataset_extract_nyuv2.py +76 -0
  13. code_depth/benchmark/dataset_extract/dataset_extract_scannet.py +124 -0
  14. code_depth/benchmark/dataset_extract/dataset_extract_sintel.py +110 -0
  15. code_depth/benchmark/dataset_extract/eval_utils.py +140 -0
  16. code_depth/benchmark/eval/eval.py +265 -0
  17. code_depth/benchmark/eval/eval.sh +30 -0
  18. code_depth/benchmark/eval/eval_500.sh +30 -0
  19. code_depth/benchmark/eval/eval_tae.py +295 -0
  20. code_depth/benchmark/eval/eval_tae.sh +18 -0
  21. code_depth/benchmark/eval/metric.py +117 -0
  22. code_depth/benchmark/infer/infer.py +65 -0
  23. code_depth/get_weights.sh +6 -0
  24. code_depth/large_files.txt +2 -0
  25. code_depth/requirements.txt +14 -0
  26. code_depth/run.py +81 -0
  27. code_depth/run_images_rord.py +112 -0
  28. code_depth/run_single_image.py +69 -0
  29. code_depth/utils/dc_utils.py +86 -0
  30. code_depth/utils/util.py +74 -0
  31. code_depth/video_depth_anything/dinov2.py +415 -0
  32. code_depth/video_depth_anything/dinov2_layers/__init__.py +11 -0
  33. code_depth/video_depth_anything/dinov2_layers/attention.py +83 -0
  34. code_depth/video_depth_anything/dinov2_layers/block.py +252 -0
  35. code_depth/video_depth_anything/dinov2_layers/drop_path.py +35 -0
  36. code_depth/video_depth_anything/dinov2_layers/layer_scale.py +28 -0
  37. code_depth/video_depth_anything/dinov2_layers/mlp.py +41 -0
  38. code_depth/video_depth_anything/dinov2_layers/patch_embed.py +89 -0
  39. code_depth/video_depth_anything/dinov2_layers/swiglu_ffn.py +63 -0
  40. code_depth/video_depth_anything/dpt.py +160 -0
  41. code_depth/video_depth_anything/dpt_temporal.py +114 -0
  42. code_depth/video_depth_anything/motion_module/attention.py +429 -0
  43. code_depth/video_depth_anything/motion_module/motion_module.py +297 -0
  44. code_depth/video_depth_anything/util/blocks.py +162 -0
  45. code_depth/video_depth_anything/util/transform.py +158 -0
  46. code_depth/video_depth_anything/video_depth.py +156 -0
  47. code_edit/.gradio/certificate.pem +31 -0
  48. code_edit/Flux_fill_d2i.py +53 -0
  49. code_edit/Flux_fill_infer_depth.py +64 -0
  50. code_edit/README.md +93 -0
app.py ADDED
@@ -0,0 +1,56 @@
1
+ import os
2
+ import pathlib
3
+ import subprocess
4
+ import gradio as gr
5
+ import spaces
6
+ import torch
7
+
8
+ # ---------- Weight download: always run the script from inside code_depth ----------
9
+ BASE_DIR = pathlib.Path(__file__).resolve().parent
10
+ SCRIPT_DIR = BASE_DIR / "code_depth"
11
+ GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"
12
+
13
+ def ensure_executable(path: pathlib.Path):
14
+ if not path.exists():
15
+ raise FileNotFoundError(f"Download script not found: {path}")
16
+ os.chmod(path, os.stat(path).st_mode | 0o111)
17
+
18
+ def ensure_weights() -> str:
19
+ """
20
+ Run get_weights.sh from the code_depth directory.
21
+ The script creates checkpoints/ under code_depth/ and downloads the weights into it.
22
+ Returns the absolute path: <repo_root>/code_depth/checkpoints
23
+ """
24
+ ensure_executable(GET_WEIGHTS_SH)
25
+ # The script must be run with code_depth as the working directory
26
+ subprocess.run(
27
+ ["bash", str(GET_WEIGHTS_SH)],
28
+ check=True,
29
+ cwd=str(SCRIPT_DIR),
30
+ env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
31
+ )
32
+ ckpt_dir = SCRIPT_DIR / "checkpoints"
33
+ return str(ckpt_dir)
34
+
35
+ # Fetch the weights at startup (without Persistent Storage they are wiped when the environment is rebuilt; they are re-downloaded automatically after a restart)
36
+ try:
37
+ CKPT_DIR = ensure_weights()
38
+ print(f"✅ Weights ready in: {CKPT_DIR}")
39
+ except Exception as e:
40
+ print(f"⚠️ Failed to prepare weights: {e}")
41
+ CKPT_DIR = str(SCRIPT_DIR / "checkpoints")  # still provide a path; its existence can be checked later
42
+
43
+ # ---------- Gradio inference function ----------
44
+ @spaces.GPU
45
+ def greet(n: float):
46
+ # Resolve the device inside the GPU worker
47
+ device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ zero = torch.tensor([0.0], device=device)
49
+ # Placeholder output only; load your model from CKPT_DIR here
50
+ print(f"Device in greet(): {device}")
51
+ print(f"Using checkpoints from: {CKPT_DIR}")
52
+ return f"Hello {(zero + n).item()} Tensor (device={device})"
53
+
54
+ demo = gr.Interface(fn=greet, inputs=gr.Number(label="n"), outputs=gr.Text())
55
+ if __name__ == "__main__":
56
+ demo.launch(server_name="0.0.0.0", server_port=7860)
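A possible next step for the placeholder `greet()` above, sketched here rather than taken from the commit: lazily load a checkpoint from `CKPT_DIR` inside the ZeroGPU worker, reusing `SCRIPT_DIR`/`CKPT_DIR` from this file and the model-construction code that appears later in `code_depth/app.py`. The import path and checkpoint filename are assumptions borrowed from that file.

```python
# Hypothetical sketch (not part of the commit): cache the model loaded from CKPT_DIR.
import os
import sys
import torch

sys.path.insert(0, str(SCRIPT_DIR))  # so the video_depth_anything imports in code_depth/ resolve
from video_depth_anything.video_depth import VideoDepthAnything  # same import as code_depth/app.py

_MODEL = None

def get_model(encoder: str = "vitl"):
    """Lazily load and cache the model; filename pattern assumed from code_depth/app.py."""
    global _MODEL
    if _MODEL is None:
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        }
        ckpt = os.path.join(CKPT_DIR, f"video_depth_anything_{encoder}.pth")
        model = VideoDepthAnything(**model_configs[encoder])
        model.load_state_dict(torch.load(ckpt, map_location="cpu"), strict=True)
        _MODEL = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
    return _MODEL
```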
code_depth/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
code_depth/README.md ADDED
@@ -0,0 +1,120 @@
1
+ <div align="center">
2
+ <h1>Video Depth Anything</h1>
3
+
4
+ [**Sili Chen**](https://github.com/SiliChen321) · [**Hengkai Guo**](https://guohengkai.github.io/)<sup>&dagger;</sup> · [**Shengnan Zhu**](https://github.com/Shengnan-Zhu) · [**Feihu Zhang**](https://github.com/zhizunhu)
5
+ <br>
6
+ [**Zilong Huang**](http://speedinghzl.github.io/) · [**Jiashi Feng**](https://scholar.google.com.sg/citations?user=Q8iay0gAAAAJ&hl=en) · [**Bingyi Kang**](https://bingykang.github.io/)<sup>&dagger;</sup>
7
+ <br>
8
+ ByteDance
9
+ <br>
10
+ &dagger;Corresponding author
11
+
12
+ <a href="https://arxiv.org/abs/2501.12375"><img src='https://img.shields.io/badge/arXiv-Video Depth Anything-red' alt='Paper PDF'></a>
13
+ <a href='https://videodepthanything.github.io'><img src='https://img.shields.io/badge/Project_Page-Video Depth Anything-green' alt='Project Page'></a>
14
+ <a href='https://huggingface.co/spaces/depth-anything/Video-Depth-Anything'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>
15
+ </div>
16
+
17
+ </div>
18
+
19
+ This work presents **Video Depth Anything**, built on [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2), which can be applied to arbitrarily long videos without compromising quality, consistency, or generalization ability. Compared with diffusion-based models, it offers faster inference, fewer parameters, and more consistent depth accuracy.
20
+
21
+ ![teaser](assets/teaser_video_v2.png)
22
+
23
+ ## News
24
+ - **2025-03-11:** Add full dataset inference and evaluation scripts.
25
+ - **2025-02-08:** Enable autocast inference. Support grayscale video, NPZ and EXR output formats.
26
+ - **2025-01-21:** Paper, project page, code, models, and demo are all released.
27
+
28
+
29
+ ## Release Notes
30
+ - **2025-02-08:** 🚀🚀🚀 Inference speed and memory usage improvement
31
+ <table>
32
+ <thead>
33
+ <tr>
34
+ <th rowspan="2" style="text-align: center;">Model</th>
35
+ <th colspan="2">Latency (ms)</th>
36
+ <th colspan="2">GPU VRAM (GB)</th>
37
+ </tr>
38
+ <tr>
39
+ <th>FP32</th>
40
+ <th>FP16</th>
41
+ <th>FP32</th>
42
+ <th>FP16</th>
43
+ </tr>
44
+ </thead>
45
+ <tbody>
46
+ <tr>
47
+ <td>Video-Depth-Anything-V2-Small</td>
48
+ <td>9.1</td>
49
+ <td><strong>7.5</strong></td>
50
+ <td>7.3</td>
51
+ <td><strong>6.8</strong></td>
52
+ </tr>
53
+ <tr>
54
+ <td>Video-Depth-Anything-V2-Large</td>
55
+ <td>67</td>
56
+ <td><strong>14</strong></td>
57
+ <td>26.7</td>
58
+ <td><strong>23.6</strong></td>
59
+ </tr>
+ </tbody>
60
+ </table>
61
+
62
+ The Latency and GPU VRAM results are obtained on a single A100 GPU with an input of shape 1 × 32 × 518 × 518.
63
+
64
+ ## Pre-trained Models
65
+ We provide **two models** of varying scales for robust and consistent video depth estimation:
66
+
67
+ | Model | Params | Checkpoint |
68
+ |:-|-:|:-:|
69
+ | Video-Depth-Anything-V2-Small | 28.4M | [Download](https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth?download=true) |
70
+ | Video-Depth-Anything-V2-Large | 381.8M | [Download](https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth?download=true) |
71
+
72
+ ## Usage
73
+
74
+ ### Preparation
75
+
76
+ ```bash
77
+ git clone https://github.com/DepthAnything/Video-Depth-Anything
78
+ cd Video-Depth-Anything
79
+ pip install -r requirements.txt
80
+ ```
81
+
82
+ Download the checkpoints listed [here](#pre-trained-models) and put them under the `checkpoints` directory, or fetch them with the provided script:
83
+ ```bash
84
+ bash get_weights.sh
85
+ ```
86
+
87
+ ### Run inference on a video
88
+ ```bash
89
+ python3 run.py --input_video ./assets/example_videos/davis_rollercoaster.mp4 --output_dir ./outputs --encoder vitl
90
+ ```
91
+
92
+ Options:
93
+ - `--input_video`: path of input video
94
+ - `--output_dir`: path to save the output results
95
+ - `--input_size` (optional): By default, we use input size `518` for model inference.
96
+ - `--max_res` (optional): By default, we use maximum resolution `1280` for model inference.
97
+ - `--encoder` (optional): `vits` for Video-Depth-Anything-V2-Small, `vitl` for Video-Depth-Anything-V2-Large.
98
+ - `--max_len` (optional): maximum length of the input video, `-1` means no limit
99
+ - `--target_fps` (optional): target fps of the input video, `-1` means the original fps
100
+ - `--fp32` (optional): Use `fp32` precision for inference. By default, we use `fp16`.
101
+ - `--grayscale` (optional): Save the grayscale depth map, without applying color palette.
102
+ - `--save_npz` (optional): Save the depth map in `npz` format.
103
+ - `--save_exr` (optional): Save the depth map in `exr` format.
104
+
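The options above can be combined freely. For example, to run the small model in fp32 precision on at most 500 frames resampled to 24 fps and additionally save NPZ output (flag names are taken verbatim from the list; the paths are just placeholders):

```bash
python3 run.py \
    --input_video ./assets/example_videos/davis_rollercoaster.mp4 \
    --output_dir ./outputs \
    --encoder vits \
    --max_len 500 \
    --target_fps 24 \
    --fp32 \
    --save_npz
```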
105
+ ## Citation
106
+
107
+ If you find this project useful, please consider citing:
108
+
109
+ ```bibtex
110
+ @article{video_depth_anything,
111
+ title={Video Depth Anything: Consistent Depth Estimation for Super-Long Videos},
112
+ author={Chen, Sili and Guo, Hengkai and Zhu, Shengnan and Zhang, Feihu and Huang, Zilong and Feng, Jiashi and Kang, Bingyi},
113
+ journal={arXiv:2501.12375},
114
+ year={2025}
115
+ }
116
+ ```
117
+
118
+
119
+ ## LICENSE
120
+ The Video-Depth-Anything-Small model is under the Apache-2.0 license. The Video-Depth-Anything-Large model is under the CC-BY-NC-4.0 license. For business cooperation, please contact Hengkai Guo at guohengkaighk@gmail.com.
code_depth/app.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import gradio as gr
15
+
16
+ import numpy as np
17
+ import os
18
+ import torch
19
+
20
+ from video_depth_anything.video_depth import VideoDepthAnything
21
+ from utils.dc_utils import read_video_frames, save_video
22
+
23
+ examples = [
24
+ ['assets/example_videos/davis_rollercoaster.mp4', -1, -1, 1280],
25
+ ]
26
+
27
+ model_configs = {
28
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
29
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
30
+ }
31
+
32
+ encoder='vitl'
33
+
34
+ video_depth_anything = VideoDepthAnything(**model_configs[encoder])
35
+ video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{encoder}.pth', map_location='cpu'), strict=True)
36
+ video_depth_anything = video_depth_anything.to('cuda').eval()
37
+
38
+
39
+ def infer_video_depth(
40
+ input_video: str,
41
+ max_len: int = -1,
42
+ target_fps: int = -1,
43
+ max_res: int = 1280,
44
+ output_dir: str = './outputs',
45
+ input_size: int = 518,
46
+ ):
47
+ frames, target_fps = read_video_frames(input_video, max_len, target_fps, max_res)
48
+ depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=input_size, device='cuda')
49
+
50
+ video_name = os.path.basename(input_video)
51
+ if not os.path.exists(output_dir):
52
+ os.makedirs(output_dir)
53
+
54
+ processed_video_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
55
+ depth_vis_path = os.path.join(output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
56
+ save_video(frames, processed_video_path, fps=fps)
57
+ save_video(depths, depth_vis_path, fps=fps, is_depths=True)
58
+
59
+ return [processed_video_path, depth_vis_path]
60
+
61
+
62
+ def construct_demo():
63
+ with gr.Blocks(analytics_enabled=False) as demo:
64
+ gr.Markdown(
65
+ f"""
66
+ blablabla
67
+ """
68
+ )
69
+
70
+ with gr.Row(equal_height=True):
71
+ with gr.Column(scale=1):
72
+ input_video = gr.Video(label="Input Video")
73
+
74
+ # with gr.Tab(label="Output"):
75
+ with gr.Column(scale=2):
76
+ with gr.Row(equal_height=True):
77
+ processed_video = gr.Video(
78
+ label="Preprocessed video",
79
+ interactive=False,
80
+ autoplay=True,
81
+ loop=True,
82
+ show_share_button=True,
83
+ scale=5,
84
+ )
85
+ depth_vis_video = gr.Video(
86
+ label="Generated Depth Video",
87
+ interactive=False,
88
+ autoplay=True,
89
+ loop=True,
90
+ show_share_button=True,
91
+ scale=5,
92
+ )
93
+
94
+ with gr.Row(equal_height=True):
95
+ with gr.Column(scale=1):
96
+ with gr.Row(equal_height=False):
97
+ with gr.Accordion("Advanced Settings", open=False):
98
+ max_len = gr.Slider(
99
+ label="max process length",
100
+ minimum=-1,
101
+ maximum=1000,
102
+ value=-1,
103
+ step=1,
104
+ )
105
+ target_fps = gr.Slider(
106
+ label="target FPS",
107
+ minimum=-1,
108
+ maximum=30,
109
+ value=15,
110
+ step=1,
111
+ )
112
+ max_res = gr.Slider(
113
+ label="max side resolution",
114
+ minimum=480,
115
+ maximum=1920,
116
+ value=1280,
117
+ step=1,
118
+ )
119
+ generate_btn = gr.Button("Generate")
120
+ with gr.Column(scale=2):
121
+ pass
122
+
123
+ gr.Examples(
124
+ examples=examples,
125
+ inputs=[
126
+ input_video,
127
+ max_len,
128
+ target_fps,
129
+ max_res
130
+ ],
131
+ outputs=[processed_video, depth_vis_video],
132
+ fn=infer_video_depth,
133
+ cache_examples="lazy",
134
+ )
135
+
136
+ generate_btn.click(
137
+ fn=infer_video_depth,
138
+ inputs=[
139
+ input_video,
140
+ max_len,
141
+ target_fps,
142
+ max_res
143
+ ],
144
+ outputs=[processed_video, depth_vis_video],
145
+ )
146
+
147
+ return demo
148
+
149
+ if __name__ == "__main__":
150
+ demo = construct_demo()
151
+ demo.queue()
152
+ demo.launch(server_name="0.0.0.0")
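A usage note, not part of the commit: because `infer_video_depth()` is an ordinary function, it can also be exercised without the Gradio UI, assuming the working directory is `code_depth/` and a checkpoint is present in `./checkpoints` (the paths below reuse the bundled example clip).

```python
# Minimal sketch: call the handler above directly instead of going through the UI.
src_path, depth_path = infer_video_depth(
    "assets/example_videos/davis_rollercoaster.mp4",
    max_len=-1,      # -1: process every frame
    target_fps=-1,   # -1: keep the original frame rate
    max_res=1280,
    output_dir="./outputs",
)
print(src_path, depth_path)
```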
code_depth/assets/example_videos/Tokyo-Walk_rgb.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:097f16c33dd8c8d1d2a24d9ea31a90b76bd0ee324b958a47385183e3547a63a8
3
+ size 2251450
code_depth/assets/example_videos/davis_rollercoaster.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7268cbecd9806a1e90a416de50dc02e50b4ae01428d5971837cf679dd0c87cb8
3
+ size 1809560
code_depth/assets/teaser_video_v2.png ADDED

Git LFS Details

  • SHA256: 7ab2bf5f739de9d00adafe15ac4225143b59e208b8f79af7dc22c417c3a4584f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.8 MB
code_depth/benchmark/README.md ADDED
@@ -0,0 +1,34 @@
1
+ # BENCHMARK
2
+
3
+ ## Prepare Dataset
4
+ Download datasets from the following links:
5
+ [sintel](http://sintel.is.tue.mpg.de/) [kitti](https://www.cvlibs.net/datasets/kitti/) [bonn](https://www.ipb.uni-bonn.de/data/rgbd-dynamic-dataset/index.html) [scannet](http://www.scan-net.org/) [nyuv2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html)
6
+
7
+ ```bash
8
+ pip3 install natsort
9
+ cd benchmark/dataset_extract
10
+ python3 dataset_extract_${dataset}.py
11
+ ```
12
+ Each script extracts its dataset into the `benchmark/dataset_extract/dataset` folder and also generates the corresponding JSON annotation file (its layout is sketched below).
13
+
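For orientation, here is a sketch of the JSON layout the extraction scripts produce, inferred from `gen_json()` in `benchmark/dataset_extract/eval_utils.py`. The sequence and file names are illustrative; the Bonn factor of 5000.0 is the value hard-coded there.

```python
# Layout of e.g. bonn_video.json as written by gen_json(); paths are relative to the
# dataset root and "factor" converts stored PNG values to metric depth.
{
    "bonn": [                                   # one key per dataset
        {
            "sequence_name": [                  # one dict per sequence (illustrative name)
                {
                    "image": "sequence_name/rgb/0001.png",       # illustrative file name
                    "gt_depth": "sequence_name/depth/0001.png",  # illustrative file name
                    "factor": 5000.0,
                },
                # ... one entry per selected frame
            ]
        },
        # ... one block per sequence directory
    ]
}
```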
14
+ ## Run inference
15
+ ```bash
16
+ python3 benchmark/infer/infer.py \
17
+ --infer_path ${out_path} \
18
+ --json_file ${json_path} \
19
+ --datasets ${dataset}
20
+ ```
21
+ Options:
22
+ - `--infer_path`: path to save the output results
23
+ - `--json_file`: path to the json file for the dataset
24
+ - `--datasets`: dataset name, choose from `sintel`, `kitti`, `bonn`, `scannet`, `nyuv2`
25
+
26
+ ## Run evaluation
27
+ ```bash
28
+ ## tae
29
+ bash benchmark/eval/eval_tae.sh ${out_path} benchmark/dataset_extract/dataset
30
+ ## ~110 frames, like DepthCrafter
31
+ bash benchmark/eval/eval.sh ${out_path} benchmark/dataset_extract/dataset
32
+ ## ~500 frames
33
+ bash benchmark/eval/eval_500.sh ${out_path} benchmark/dataset_extract/dataset
34
+ ```
code_depth/benchmark/__init__.py ADDED
File without changes
code_depth/benchmark/dataset_extract/dataset_extract_bonn.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import numpy as np
3
+ import os.path as osp
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import cv2
7
+ import csv
8
+ import json
9
+ import glob
10
+ import shutil
11
+ from natsort import natsorted
12
+
13
+ from eval_utils import gen_json, get_sorted_files, even_or_odd, copy_crop_files
14
+
15
+ def extract_bonn(
16
+ root,
17
+ depth_root,
18
+ saved_dir,
19
+ sample_len,
20
+ datatset_name,
21
+ ):
22
+ scenes_names = os.listdir(depth_root)
23
+ all_samples = []
24
+ for i, seq_name in enumerate(tqdm(scenes_names)):
25
+ # load all images
26
+ all_img_names = get_sorted_files(
27
+ root=osp.join(depth_root, seq_name, "rgb"), suffix=".png"
28
+ )
29
+ all_depth_names = get_sorted_files(
30
+ root=osp.join(depth_root, seq_name, "depth"), suffix=".png"
31
+ )
32
+
33
+ seq_len = len(all_img_names)
34
+ step = sample_len if sample_len > 0 else seq_len
35
+
36
+ for ref_idx in range(0, seq_len, step):
37
+ print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
38
+
39
+ if (ref_idx + step) <= seq_len:
40
+ ref_e = ref_idx + step
41
+ else:
42
+ continue
43
+
44
+ for idx in range(ref_idx, ref_e):
45
+ im_path = osp.join(
46
+ root, seq_name, "rgb", all_img_names[idx]
47
+ )
48
+ depth_path = osp.join(
49
+ depth_root, seq_name, "depth", all_depth_names[idx]
50
+ )
51
+ out_img_path = osp.join(
52
+ saved_dir, datatset_name,seq_name, "rgb", all_img_names[idx]
53
+ )
54
+ out_depth_path = osp.join(
55
+ saved_dir, datatset_name,seq_name, "depth", all_depth_names[idx]
56
+ )
57
+
58
+ copy_crop_files(
59
+ im_path=im_path,
60
+ depth_path=depth_path,
61
+ out_img_path=out_img_path,
62
+ out_depth_path=out_depth_path,
63
+ dataset=datatset_name,
64
+ )
65
+
66
+ # 110 frames, like DepthCrafter
67
+ out_json_path = osp.join(saved_dir, datatset_name, "bonn_video.json")
68
+ gen_json(
69
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
70
+ start_id=30, end_id=140, step=1, save_path=out_json_path)
71
+
72
+ #~500 frames in paper
73
+ out_json_path = osp.join(saved_dir, datatset_name, "bonn_video_500.json")
74
+ gen_json(
75
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
76
+ start_id=0, end_id=500, step=1, save_path=out_json_path)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ extract_bonn(
81
+ root="path/to/Bonn-RGBD",
82
+ depth_root="path/to/Bonn-RGBD",
83
+ saved_dir="./benchmark/datasets/",
84
+ sample_len=-1,
85
+ datatset_name="bonn",
86
+ )
code_depth/benchmark/dataset_extract/dataset_extract_kitti.py ADDED
@@ -0,0 +1,84 @@
1
+ import os
2
+ import numpy as np
3
+ import os.path as osp
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import csv
7
+ import cv2
8
+ import json
9
+ import glob
10
+ import shutil
11
+ from natsort import natsorted
12
+
13
+ from eval_utils import even_or_odd
14
+ from eval_utils import gen_json, get_sorted_files, copy_crop_files
15
+
16
+ def extract_kitti(
17
+ root,
18
+ depth_root,
19
+ sample_len=-1,
20
+ saved_dir="",
21
+ datatset_name="",
22
+ ):
23
+ scenes_names = os.listdir(depth_root)
24
+ all_samples = []
25
+ for i, seq_name in enumerate(tqdm(scenes_names)):
26
+
27
+ all_img_names = get_sorted_files(
28
+ osp.join(depth_root, seq_name, "proj_depth/groundtruth/image_02"), suffix=".png"
29
+ )
30
+
31
+ seq_len = len(all_img_names)
32
+ step = sample_len if sample_len > 0 else seq_len
33
+
34
+ for ref_idx in range(0, seq_len, step):
35
+ print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
36
+ video_imgs = []
37
+ video_depths = []
38
+
39
+ if (ref_idx + step) <= seq_len:
40
+ ref_e = ref_idx + step
41
+ else:
42
+ continue
43
+
44
+ for idx in range(ref_idx, ref_e):
45
+ im_path = osp.join(
46
+ root, seq_name[0:10], seq_name, "image_02/data", all_img_names[idx]
47
+ )
48
+ depth_path = osp.join(
49
+ depth_root, seq_name, "proj_depth/groundtruth/image_02", all_img_names[idx],
50
+ )
51
+ out_img_path = osp.join(
52
+ saved_dir, datatset_name,seq_name, "rgb", all_img_names[idx]
53
+ )
54
+ out_depth_path = osp.join(
55
+ saved_dir, datatset_name,seq_name, "depth", all_img_names[idx]
56
+ )
57
+ copy_crop_files(
58
+ im_path=im_path,
59
+ depth_path=depth_path,
60
+ out_img_path=out_img_path,
61
+ out_depth_path=out_depth_path,
62
+ dataset=datatset_name,
63
+ )
64
+
65
+ # 110 frames, like DepthCrafter
66
+ out_json_path = osp.join(saved_dir, datatset_name, "kitti_video.json")
67
+ gen_json(
68
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
69
+ start_id=0, end_id=110, step=1, save_path=out_json_path)
70
+
71
+ #~500 frames in paper
72
+ out_json_path = osp.join(saved_dir, datatset_name, "kitti_video_500.json")
73
+ gen_json(
74
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
75
+ start_id=0, end_id=500, step=1, save_path=out_json_path)
76
+
77
+ if __name__ == "__main__":
78
+ extract_kitti(
79
+ root="path/to/kitti",
80
+ depth_root="path/to/kitti/val",
81
+ saved_dir="./benchmark/datasets/",
82
+ sample_len=-1,
83
+ datatset_name="kitti",
84
+ )
code_depth/benchmark/dataset_extract/dataset_extract_nyuv2.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ import numpy as np
3
+ import os.path as osp
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import csv
7
+ import cv2
8
+ import json
9
+ import glob
10
+ from natsort import natsorted
11
+ import shutil
12
+
13
+ from eval_utils import gen_json, get_sorted_files, copy_crop_files
14
+
15
+ def extract_nyuv2(
16
+ root,
17
+ sample_len=-1,
18
+ datatset_name="",
19
+ saved_dir="",
20
+ ):
21
+ scenes_names = os.listdir(root)
22
+ scenes_names = sorted(scenes_names)
23
+ all_samples = []
24
+ for i, seq_name in enumerate(tqdm(scenes_names)):
25
+ all_img_names = get_sorted_files(
26
+ osp.join(root, seq_name, "rgb"), suffix=".jpg")
27
+
28
+ seq_len = len(all_img_names)
29
+ step = sample_len if sample_len > 0 else seq_len
30
+
31
+ for ref_idx in range(0, seq_len, step):
32
+ print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
33
+
34
+ if (ref_idx + step) <= seq_len:
35
+ ref_e = ref_idx + step
36
+ else:
37
+ continue
38
+
39
+ for idx in range(ref_idx, ref_e):
40
+ im_path = osp.join(
41
+ root, seq_name, "rgb", all_img_names[idx]
42
+ )
43
+ depth_path = osp.join(
44
+ root, seq_name, "depth", all_img_names[idx][:-3] + "png"
45
+ )
46
+ out_img_path = osp.join(
47
+ saved_dir, datatset_name, seq_name, "rgb", all_img_names[idx]
48
+ )
49
+ out_depth_path = osp.join(
50
+ saved_dir, datatset_name, seq_name, "depth", all_img_names[idx][:-3] + "png"
51
+ )
52
+
53
+ copy_crop_files(
54
+ im_path=im_path,
55
+ depth_path=depth_path,
56
+ out_img_path=out_img_path,
57
+ out_depth_path=out_depth_path,
58
+ dataset=datatset_name,
59
+ )
60
+
61
+ #~500 frames in paper
62
+ out_json_path = osp.join(saved_dir, datatset_name, "nyuv2_video_500.json")
63
+ gen_json(
64
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
65
+ start_id=0,end_id=500,step=1,
66
+ save_path=out_json_path)
67
+
68
+ if __name__ == "__main__":
69
+ # We use MATLAB to extract 8 scenes from NYUv2:
70
+ #--basement_0001a, bookstore_0001a, cafe_0001a, classroom_0001a, kitchen_0003, office_0004, playroom_0002, study_0002
71
+ extract_nyuv2(
72
+ root="path/to/nyuv2",
73
+ saved_dir="./benchmark/datasets/",
74
+ sample_len=-1,
75
+ datatset_name="nyuv2",
76
+ )
code_depth/benchmark/dataset_extract/dataset_extract_scannet.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ import numpy as np
3
+ import os.path as osp
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import csv
7
+ import cv2
8
+ import json
9
+ import glob
10
+ from natsort import natsorted
11
+ import shutil
12
+
13
+ from eval_utils import gen_json, gen_json_scannet_tae, get_sorted_files, copy_crop_files
14
+
15
+ def extract_scannet(
16
+ root,
17
+ sample_len=-1,
18
+ datatset_name="",
19
+ saved_dir="",
20
+ ):
21
+ scenes_names = os.listdir(root)
22
+ scenes_names = sorted(scenes_names)[:100]
23
+ all_samples = []
24
+ for i, seq_name in enumerate(tqdm(scenes_names)):
25
+ all_img_names = get_sorted_files(
26
+ osp.join(root, seq_name, "color"), suffix=".jpg")
27
+ all_img_names = all_img_names[:510]
28
+
29
+ seq_len = len(all_img_names)
30
+ step = sample_len if sample_len > 0 else seq_len
31
+
32
+ for ref_idx in range(0, seq_len, step):
33
+ print(f"Progress: {seq_name}, {ref_idx // step + 1} / {seq_len//step}")
34
+
35
+ video_imgs = []
36
+ video_depths = []
37
+
38
+ if (ref_idx + step) <= seq_len:
39
+ ref_e = ref_idx + step
40
+ else:
41
+ continue
42
+
43
+ for idx in range(ref_idx, ref_e):
44
+ im_path = osp.join(
45
+ root, seq_name, "color", all_img_names[idx]
46
+ )
47
+ depth_path = osp.join(
48
+ root, seq_name, "depth", all_img_names[idx][:-3] + "png"
49
+ )
50
+ pose_path = osp.join(
51
+ root, seq_name, "pose", all_img_names[idx][:-3] + "txt"
52
+ )
53
+ out_img_path = osp.join(
54
+ saved_dir, datatset_name, seq_name, "color", all_img_names[idx]
55
+ )
56
+ out_depth_path = osp.join(
57
+ saved_dir, datatset_name, seq_name, "depth", all_img_names[idx][:-3] + "png"
58
+ )
59
+
60
+ copy_crop_files(
61
+ im_path=im_path,
62
+ depth_path=depth_path,
63
+ out_img_path=out_img_path,
64
+ out_depth_path=out_depth_path,
65
+ dataset=datatset_name,
66
+ )
67
+
68
+ origin_img = np.array(Image.open(im_path))
69
+ out_img_origin_path = osp.join(
70
+ saved_dir, datatset_name, seq_name, "color_origin", all_img_names[idx]
71
+ )
72
+ out_pose_path = osp.join(
73
+ saved_dir, datatset_name, seq_name, "pose", all_img_names[idx][:-3] + "txt"
74
+ )
75
+
76
+ os.makedirs(osp.dirname(out_img_origin_path), exist_ok=True)
77
+ os.makedirs(osp.dirname(out_pose_path), exist_ok=True)
78
+
79
+ cv2.imwrite(
80
+ out_img_origin_path,
81
+ origin_img,
82
+ )
83
+ shutil.copyfile(pose_path, out_pose_path)
84
+
85
+ intrinsic_path = osp.join(
86
+ root, seq_name, "intrinsic", "intrinsic_depth.txt"
87
+ )
88
+ out_intrinsic_path = osp.join(
89
+ saved_dir, datatset_name, seq_name, "intrinsic", "intrinsic_depth.txt"
90
+ )
91
+ os.makedirs(osp.dirname(out_intrinsic_path), exist_ok=True)
92
+ shutil.copyfile(intrinsic_path, out_intrinsic_path)
93
+
94
+ # 90 frames like DepthCraft
95
+ out_json_path = osp.join(saved_dir, datatset_name, "scannet_video.json")
96
+ gen_json(
97
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
98
+ start_id=0,end_id=90*3,step=3,
99
+ save_path=out_json_path,
100
+ )
101
+
102
+ #~500 frames in paper
103
+ out_json_path = osp.join(saved_dir, datatset_name, "scannet_video_500.json")
104
+ gen_json(
105
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
106
+ start_id=0,end_id=500,step=1,
107
+ save_path=out_json_path,
108
+ )
109
+
110
+ # tae
111
+ out_json_path = osp.join(saved_dir, datatset_name, "scannet_video_tae.json")
112
+ gen_json_scannet_tae(
113
+ root_path=osp.join(saved_dir, datatset_name),
114
+ start_id=0,end_id=192,step=1,
115
+ save_path=out_json_path,
116
+ )
117
+
118
+ if __name__ == "__main__":
119
+ extract_scannet(
120
+ root="path/to/scannet",
121
+ saved_dir="./benchmark/datasets/",
122
+ sample_len=-1,
123
+ datatset_name="scannet",
124
+ )
code_depth/benchmark/dataset_extract/dataset_extract_sintel.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
7
+
8
+
9
+ import os
10
+ import numpy as np
11
+ import os.path as osp
12
+ from PIL import Image
13
+ from tqdm import tqdm
14
+ import csv
15
+ import imageio
16
+ import cv2
17
+ import json
18
+ import glob
19
+ import shutil
20
+
21
+ from eval_utils import gen_json, get_sorted_files
22
+
23
+ TAG_FLOAT = 202021.25
24
+ TAG_CHAR = "PIEH"
25
+
26
+ def depth_read(filename):
27
+ """Read depth data from file, return as numpy array."""
28
+ f = open(filename, "rb")
29
+ check = np.fromfile(f, dtype=np.float32, count=1)[0]
30
+ assert (
31
+ check == TAG_FLOAT
32
+ ), " depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? ".format(
33
+ TAG_FLOAT, check
34
+ )
35
+ width = np.fromfile(f, dtype=np.int32, count=1)[0]
36
+ height = np.fromfile(f, dtype=np.int32, count=1)[0]
37
+ size = width * height
38
+ assert (
39
+ width > 0 and height > 0 and size > 1 and size < 100000000
40
+ ), " depth_read:: Wrong input size (width = {0}, height = {1}).".format(
41
+ width, height
42
+ )
43
+ depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
44
+ return depth
45
+
46
+ def extract_sintel(
47
+ root,
48
+ depth_root,
49
+ sample_len=-1,
50
+ datatset_name="",
51
+ saved_dir="",
52
+ ):
53
+ scenes_names = os.listdir(root)
54
+ all_samples = []
55
+ for i, seq_name in enumerate(tqdm(scenes_names)):
56
+ all_img_names = get_sorted_files(
57
+ os.path.join(root, seq_name), suffix=".png")
58
+
59
+ seq_len = len(all_img_names)
60
+ step = sample_len if sample_len > 0 else seq_len
61
+
62
+ for ref_idx in range(0, seq_len, step):
63
+ print(f"Progress: {seq_name}, {ref_idx // step} / {seq_len // step}")
64
+
65
+ if (ref_idx + step) <= seq_len:
66
+ ref_e = ref_idx + step
67
+ else:
68
+ continue
69
+
70
+ for idx in range(ref_idx, ref_e):
71
+ im_path = osp.join(
72
+ root, seq_name, all_img_names[idx]
73
+ )
74
+ depth_path = osp.join(
75
+ depth_root, seq_name, all_img_names[idx][:-3] + "dpt"
76
+ )
77
+ out_img_path = osp.join(
78
+ saved_dir, datatset_name,'clean', seq_name, all_img_names[idx]
79
+ )
80
+ out_depth_path = osp.join(
81
+ saved_dir, datatset_name,'depth', seq_name, all_img_names[idx][:-3] + "png"
82
+ )
83
+ depth = depth_read(depth_path)
84
+ img = np.array(Image.open(im_path))
85
+
86
+ os.makedirs(osp.dirname(out_img_path), exist_ok=True)
87
+ os.makedirs(osp.dirname(out_depth_path), exist_ok=True)
88
+
89
+ cv2.imwrite(
90
+ out_img_path,
91
+ img,
92
+ )
93
+ cv2.imwrite(
94
+ out_depth_path,
95
+ depth.astype(np.uint16)
96
+ )
97
+
98
+ gen_json(
99
+ root_path=osp.join(saved_dir, datatset_name), dataset=datatset_name,
100
+ start_id=0,end_id=100,step=1,
101
+ save_path=osp.join(saved_dir, datatset_name, "sintel_video.json"),)
102
+
103
+ if __name__ == "__main__":
104
+ extract_sintel(
105
+ root="path/to/training/clean",
106
+ depth_root="path/to/depth",
107
+ saved_dir="./benchmark/datasets/",
108
+ sample_len=-1,
109
+ datatset_name="sintel",
110
+ )
code_depth/benchmark/dataset_extract/eval_utils.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ import numpy as np
3
+ import os.path as osp
4
+ import json
5
+ import glob
6
+ import cv2
7
+ import shutil
8
+ from PIL import Image
9
+ from natsort import natsorted
10
+
11
+ def even_or_odd(num):
12
+ if num % 2 == 0:
13
+ return num
14
+ else:
15
+ return num - 1
16
+
17
+
18
+ def gen_json(root_path, dataset, start_id, end_id, step, save_path=None):
19
+ rgb_name = "rgb"
20
+ if dataset == "kitti":
21
+ factor = 256.0
22
+ elif dataset == "nyuv2":
23
+ factor = 6000.0
24
+ elif dataset == "bonn":
25
+ factor = 5000.0
26
+ elif dataset == 'sintel':
27
+ factor = 65535 / 650
28
+ rgb_name = "clean"
29
+ elif dataset == 'scannet':
30
+ factor = 1000.0
31
+ rgb_name = "color"
32
+ else:
33
+ raise NotImplementedError
34
+
35
+ data = {}
36
+ data[dataset] = []
37
+ pieces = glob.glob(osp.join(root_path, "*"))
38
+ count = 0
39
+ for piece in pieces:
40
+ if not osp.isdir(piece):
41
+ continue
42
+ name = piece.split('/')[-1]
43
+ name_dict = {name:[]}
44
+ images = glob.glob(osp.join(piece, rgb_name, "*.png")) + glob.glob(osp.join(piece, rgb_name, "*.jpg"))
45
+ images = natsorted(images)
46
+ depths = glob.glob(osp.join(piece, "depth/*.png"))
47
+ depths = natsorted(depths)
48
+ images = images[start_id:end_id:step]
49
+ depths = depths[start_id:end_id:step]
50
+
51
+ for i in range(len(images)):
52
+ image = images[i]
53
+ xx = image[len(root_path)+1:]
54
+ depth = depths[i][len(root_path)+1:]
55
+ tmp = {}
56
+ tmp["image"] = xx
57
+ tmp["gt_depth"] = depth
58
+ tmp["factor"] = factor
59
+ name_dict[name].append(tmp)
60
+ data[dataset].append(name_dict)
61
+ with open(save_path, "w") as f:
62
+ json.dump(data, f, indent= 4)
63
+
64
+
65
+ def gen_json_scannet_tae(root_path, start_id, end_id, step, save_path=None):
66
+ data = {}
67
+ data["scannet"] = []
68
+ pieces = glob.glob(osp.join(root_path, "*"))
69
+
70
+ color = 'color_origin'
71
+
72
+ for piece in pieces:
73
+ if not osp.isdir(piece):
74
+ continue
75
+ name = piece.split('/')[-1]
76
+ name_dict = {name:[]}
77
+ images = glob.glob(osp.join(piece,color, "*.jpg"))
78
+ images = natsorted(images)
79
+ depths = glob.glob(osp.join(piece, "depth/*.png"))
80
+ depths = natsorted(depths)
81
+ images = images[start_id:end_id:step]
82
+ depths = depths[start_id:end_id:step]
83
+ print(f"sequence frame number: {piece}")
84
+ count = 0
85
+ for i in range(len(images)):
86
+ image = images[i]
87
+ xx = image[len(root_path)+1:]
88
+ depth = depths[i][len(root_path)+1:]
89
+
90
+ base_path = osp.dirname(image)
91
+ base_path = base_path.replace(color, 'intrinsic')
92
+ K = np.loadtxt(base_path + '/intrinsic_depth.txt')
93
+
94
+ pose_path = image.replace(color, 'pose').replace('.jpg', '.txt')
95
+ pose = np.loadtxt(pose_path)
96
+
97
+ tmp = {}
98
+ tmp["image"] = xx
99
+ tmp["gt_depth"] = depth
100
+ tmp["factor"] = 1000.0
101
+ tmp["K"] = K.tolist()
102
+ tmp["pose"] = pose.tolist()
103
+ name_dict[name].append(tmp)
104
+ data["scannet"].append(name_dict)
105
+
106
+ with open(save_path, "w") as f:
107
+ json.dump(data, f, indent= 4)
108
+
109
+
110
+ def get_sorted_files(root_path, suffix):
111
+ all_img_names = os.listdir(root_path)
112
+ all_img_names = [x for x in all_img_names if x.endswith(suffix)]
113
+ print(f"sequence frame number: {len(all_img_names)}")
114
+
115
+ all_img_names.sort()
116
+ all_img_names = sorted(all_img_names, key=lambda x: int(x.split(".")[0][-4:]))
117
+
118
+ return all_img_names
119
+
120
+ def copy_crop_files(im_path, depth_path, out_img_path, out_depth_path, dataset):
121
+ img = np.array(Image.open(im_path))
122
+
123
+ if dataset == "kitti" or dataset == "bonn":
124
+ height, width = img.shape[:2]
125
+ height = even_or_odd(height)
126
+ width = even_or_odd(width)
127
+ img = img[:height, :width]
128
+ elif dataset == "nyuv2":
129
+ img = img[45:471, 41:601, :]
130
+ elif dataset == "scannet":
131
+ img = img[8:-8, 11:-11, :]
132
+
133
+ os.makedirs(osp.dirname(out_img_path), exist_ok=True)
134
+ os.makedirs(osp.dirname(out_depth_path), exist_ok=True)
135
+ cv2.imwrite(
136
+ out_img_path,
137
+ img,
138
+ )
139
+ shutil.copyfile(depth_path, out_depth_path)
140
+
code_depth/benchmark/eval/eval.py ADDED
@@ -0,0 +1,265 @@
1
+
2
+ import numpy as np
3
+ import cv2
4
+ import matplotlib.pyplot as plt
5
+ import json
6
+
7
+ import argparse
8
+ from scipy.ndimage import map_coordinates
9
+ from tqdm import tqdm
10
+ import os
11
+ import gc
12
+
13
+ import torch
14
+ from metric import *
15
+ import metric
16
+
17
+ device = 'cuda'
18
+ eval_metrics = [
19
+ "abs_relative_difference",
20
+ "rmse_linear",
21
+ "delta1_acc",
22
+ ]
23
+
24
+ def get_infer(infer_path,args, target_size = None):
25
+ if infer_path.split('.')[-1] == 'npy':
26
+ img_gray = np.load(infer_path)
27
+ img_gray = img_gray.astype(np.float32)
28
+ infer_factor = 1.0
29
+ else:
30
+ img = cv2.imread(infer_path)
31
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
32
+ img_gray = img_gray.astype(np.float32)
33
+ infer_factor = 1.0 / 255.0
34
+
35
+ infer = img_gray / infer_factor
36
+
37
+ if target_size is not None:
38
+ if infer.shape[0] != target_size[0] or infer.shape[1] != target_size[1]:
39
+ infer = cv2.resize(infer, (target_size[1], target_size[0]))
40
+ return infer
41
+
42
+ def get_gt(depth_gt_path, gt_factor, args):
43
+ if depth_gt_path.split('.')[-1] == 'npy':
44
+ depth_gt = np.load(depth_gt_path)
45
+ else:
46
+ depth_gt = cv2.imread(depth_gt_path, -1)
47
+ depth_gt = np.array(depth_gt)
48
+ depth_gt = depth_gt / gt_factor
49
+ depth_gt[depth_gt==0] = -1
50
+ return depth_gt
51
+
52
+ def get_flow(flow_path):
53
+ assert os.path.exists(flow_path)
54
+ flow = np.load(flow_path, allow_pickle=True)
55
+ return flow
56
+ def depth2disparity(depth, return_mask=False):
57
+ if isinstance(depth, np.ndarray):
58
+ disparity = np.zeros_like(depth)
59
+ non_negtive_mask = depth > 0
60
+ disparity[non_negtive_mask] = 1.0 / depth[non_negtive_mask]
61
+ if return_mask:
62
+ return disparity, non_negtive_mask
63
+ else:
64
+ return disparity
65
+
66
+ def eval_depthcrafter(infer_paths, depth_gt_paths, factors, args):
67
+ depth_errors = []
68
+ gts = []
69
+ infs = []
70
+ seq_length = args.max_eval_len
71
+ dataset_max_depth = args.max_depth_eval
72
+ for i in range(len(infer_paths)):
73
+ if not os.path.exists(infer_paths[i]):
74
+ continue
75
+ depth_gt = get_gt(depth_gt_paths[i], factors[i], args)
76
+ depth_gt = depth_gt[args.a:args.b, args.c:args.d]
77
+
78
+ infer = get_infer(infer_paths[i], args, target_size=depth_gt.shape)
79
+ gts.append(depth_gt)
80
+ infs.append(infer)
81
+ gts = np.stack(gts, axis=0)
82
+
83
+ infs = np.stack(infs, axis=0)
84
+ infs = infs[:seq_length]
85
+ gts = gts[:seq_length]
86
+ valid_mask = np.logical_and((gts>1e-3), (gts<dataset_max_depth))
87
+
88
+ gt_disp_masked = 1. / (gts[valid_mask].reshape((-1,1)).astype(np.float64) + 1e-8)
89
+ infs = np.clip(infs, a_min=1e-3, a_max=None)
90
+ pred_disp_masked = infs[valid_mask].reshape((-1,1)).astype(np.float64)
91
+
92
+ _ones = np.ones_like(pred_disp_masked)
93
+ A = np.concatenate([pred_disp_masked, _ones], axis=-1)
94
+ X = np.linalg.lstsq(A, gt_disp_masked, rcond=None)[0]
95
+ scale, shift = X
96
+ aligned_pred = scale * infs + shift
97
+ aligned_pred = np.clip(aligned_pred, a_min=1e-3, a_max=None)
98
+
99
+ pred_depth = depth2disparity(aligned_pred)
100
+ gt_depth = gts
101
+ pred_depth = np.clip(
102
+ pred_depth, a_min=1e-3, a_max=dataset_max_depth
103
+ )
104
+ sample_metric = []
105
+ metric_funcs = [getattr(metric, _met) for _met in eval_metrics]
106
+
107
+ pred_depth_ts = torch.from_numpy(pred_depth).to(device)
108
+ gt_depth_ts = torch.from_numpy(gt_depth).to(device)
109
+ valid_mask_ts = torch.from_numpy(valid_mask).to(device)
110
+
111
+ n = valid_mask.sum((-1, -2))
112
+ valid_frame = (n > 0)
113
+ pred_depth_ts = pred_depth_ts[valid_frame]
114
+ gt_depth_ts = gt_depth_ts[valid_frame]
115
+ valid_mask_ts = valid_mask_ts[valid_frame]
116
+
117
+ for met_func in metric_funcs:
118
+ _metric_name = met_func.__name__
119
+ _metric = met_func(pred_depth_ts, gt_depth_ts, valid_mask_ts).item()
120
+ sample_metric.append(_metric)
121
+ return sample_metric
122
+
123
+
124
+ def main():
125
+
126
+ parser = argparse.ArgumentParser()
127
+ parser.add_argument('--infer_path', type=str, default='')
128
+ parser.add_argument('--infer_type', type=str, default='npy')
129
+ parser.add_argument('--benchmark_path', type=str, default='')
130
+ parser.add_argument('--datasets', type=str, nargs='+', default=['vkitti', 'kitti', 'sintel', 'nyu_v2', 'tartanair', 'bonn', 'ip_lidar'])
131
+
132
+ args = parser.parse_args()
133
+
134
+ results_save_path = os.path.join(args.infer_path, 'results.txt')
135
+
136
+ for dataset in args.datasets:
137
+
138
+ file = open(results_save_path, 'a')
139
+
140
+ if dataset == 'kitti':
141
+ args.json_file = os.path.join(args.benchmark_path,'kitti/kitti_video.json')
142
+ args.root_path = os.path.join(args.benchmark_path,'kitti')
143
+ args.max_depth_eval = 80.0
144
+ args.min_depth_eval = 0.1
145
+ args.max_eval_len = 110
146
+ args.a = 0
147
+ args.b = 374
148
+ args.c = 0
149
+ args.d = 1242
150
+ if dataset == 'kitti_500':
151
+ dataset = 'kitti'
152
+ args.json_file = os.path.join(args.benchmark_path,'kitti/kitti_video_500.json')
153
+ args.root_path = os.path.join(args.benchmark_path,'kitti')
154
+ args.max_depth_eval = 80.0
155
+ args.min_depth_eval = 0.1
156
+ args.max_eval_len = 500
157
+ args.a = 0
158
+ args.b = 374
159
+ args.c = 0
160
+ args.d = 1242
161
+ elif dataset == 'sintel':
162
+ args.json_file = os.path.join(args.benchmark_path,'sintel/sintel_video.json')
163
+ args.root_path = os.path.join(args.benchmark_path,'sintel')
164
+ args.max_depth_eval = 70
165
+ args.min_depth_eval = 0.1
166
+ args.max_eval_len = 100
167
+ args.a = 0
168
+ args.b = 436
169
+ args.c = 0
170
+ args.d = 1024
171
+ elif dataset == 'nyuv2_500':
172
+ dataset = 'nyuv2'
173
+ args.json_file = os.path.join(args.benchmark_path,'nyuv2/nyuv2_video_500.json')
174
+ args.root_path = os.path.join(args.benchmark_path,'nyuv2')
175
+ args.max_depth_eval = 10.0
176
+ args.min_depth_eval = 0.1
177
+ args.max_eval_len = 500
178
+ args.a = 45
179
+ args.b = 471
180
+ args.c = 41
181
+ args.d = 601
182
+ elif dataset == 'bonn':
183
+ args.json_file = os.path.join(args.benchmark_path,'bonn/bonn_video.json')
184
+ args.root_path = os.path.join(args.benchmark_path,'bonn')
185
+ args.max_depth_eval = 10.0
186
+ args.min_depth_eval = 0.1
187
+ args.max_eval_len = 110
188
+ args.a = 0
189
+ args.b = 480
190
+ args.c = 0
191
+ args.d = 640
192
+ elif dataset == 'bonn_500':
193
+ dataset = 'bonn'
194
+ args.json_file = os.path.join(args.benchmark_path,'bonn/bonn_video_500.json')
195
+ args.root_path = os.path.join(args.benchmark_path,'bonn')
196
+ args.max_depth_eval = 10.0
197
+ args.min_depth_eval = 0.1
198
+ args.max_eval_len = 500
199
+ args.a = 0
200
+ args.b = 480
201
+ args.c = 0
202
+ args.d = 640
203
+ elif dataset == 'scannet':
204
+ args.json_file = os.path.join(args.benchmark_path,'scannet/scannet_video.json')
205
+ args.root_path = os.path.join(args.benchmark_path,'scannet')
206
+ args.max_depth_eval = 10.0
207
+ args.min_depth_eval = 0.1
208
+ args.max_eval_len = 90
209
+ args.a = 8
210
+ args.b = -8
211
+ args.c = 11
212
+ args.d = -11
213
+ elif dataset == 'scannet_500':
214
+ dataset = 'scannet'
215
+ args.json_file = os.path.join(args.benchmark_path,'scannet/scannet_video_500.json')
216
+ args.root_path = os.path.join(args.benchmark_path,'scannet')
217
+ args.max_depth_eval = 10.0
218
+ args.min_depth_eval = 0.1
219
+ args.max_eval_len = 500
220
+ args.a = 8
221
+ args.b = -8
222
+ args.c = 11
223
+ args.d = -11
224
+
225
+ with open(args.json_file, 'r') as fs:
226
+ path_json = json.load(fs)
227
+
228
+ json_data = path_json[dataset]
229
+ scale_stds = shift_stds = stable_result_fulls = stable_result_wins = 0
230
+ depth_result_fulls = np.zeros(5)
231
+ depth_result_wins = np.zeros(5)
232
+ depth_result_onlys = np.zeros(5)
233
+ count = 0
234
+ line = '-' * 50
235
+ print(f'<{line} {dataset} start {line}>')
236
+ file.write(f'<{line} {dataset} start {line}>\n')
237
+ results_all = []
238
+ for data in tqdm(json_data):
239
+ for key in data.keys():
240
+ value = data[key]
241
+ infer_paths = []
242
+ depth_gt_paths = []
243
+ flow_paths = []
244
+ factors = []
245
+ for images in value:
246
+ infer_path = (args.infer_path + '/'+ dataset + '/' + images['image']).replace('.jpg', '.npy').replace('.png', '.npy')
247
+
248
+ infer_paths.append(infer_path)
249
+ depth_gt_paths.append(args.root_path + '/' + images['gt_depth'])
250
+ factors.append(images['factor'])
251
+ infer_paths = infer_paths[:args.max_eval_len]
252
+ depth_gt_paths = depth_gt_paths[:args.max_eval_len]
253
+ factors = factors[:args.max_eval_len]
254
+ results_single = eval_depthcrafter(infer_paths, depth_gt_paths, factors, args)
255
+ results_all.append(results_single)
256
+ final_results = np.array(results_all)
257
+ final_results_mean = np.mean(final_results, axis=0)
258
+ result_dict = { 'name': dataset }
259
+ for i, metric in enumerate(eval_metrics):
260
+ result_dict[metric] = final_results_mean[i]
261
+ print(f"{metric}: {final_results_mean[i]:.4f}")
262
+ file.write(f"{metric}: {final_results_mean[i]:.4f}\n")
263
+ file.write(f'<{line} {dataset} finish {line}>\n')
+ file.close()
264
+ if __name__ == '__main__':
265
+ main()
code_depth/benchmark/eval/eval.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/sh
2
+ set -x
3
+ set -e
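+ # Usage: sh benchmark/eval/eval.sh <pred_disp_root> <benchmark_root>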
4
+
5
+ pred_disp_root=$1 # The parent directory containing the [sintel, scannet, KITTI, bonn, NYUv2] predictions
6
+ benchmark_root=$2 # The parent directory containing the [sintel, scannet, KITTI, bonn, NYUv2] ground truth
7
+
8
+ #eval sintel
9
+ python3 benchmark/eval/eval.py \
10
+ --infer_path $pred_disp_root \
11
+ --benchmark_path $benchmark_root \
12
+ --datasets sintel
13
+
14
+ #eval scannet
15
+ python3 benchmark/eval/eval.py \
16
+ --infer_path $pred_disp_root \
17
+ --benchmark_path $benchmark_root \
18
+ --datasets scannet
19
+
20
+ #eval kitti
21
+ python3 benchmark/eval/eval.py \
22
+ --infer_path $pred_disp_root \
23
+ --benchmark_path $benchmark_root \
24
+ --datasets kitti
25
+
26
+ #eval bonn
27
+ python3 benchmark/eval/eval.py \
28
+ --infer_path $pred_disp_root \
29
+ --benchmark_path $benchmark_root \
30
+ --datasets bonn
code_depth/benchmark/eval/eval_500.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/sh
2
+ set -x
3
+ set -e
4
+
5
+ pred_disp_root=$1 # The parent directory containing the [sintel, scannet, KITTI, bonn, NYUv2] predictions
6
+ benchmark_root=$2 # The parent directory containing the [sintel, scannet, KITTI, bonn, NYUv2] ground truth
7
+
8
+ #eval scannet
9
+ python3 benchmark/eval/eval.py \
10
+ --infer_path $pred_disp_root \
11
+ --benchmark_path $benchmark_root \
12
+ --datasets scannet_500
13
+
14
+ #eval kitti
15
+ python3 benchmark/eval/eval.py \
16
+ --infer_path $pred_disp_root \
17
+ --benchmark_path $benchmark_root \
18
+ --datasets kitti_500
19
+
20
+ #eval bonn
21
+ python3 benchmark/eval/eval.py \
22
+ --infer_path $pred_disp_root \
23
+ --benchmark_path $benchmark_root \
24
+ --datasets bonn_500
25
+
26
+ #eval nyu
27
+ python3 benchmark/eval/eval.py \
28
+ --infer_path $pred_disp_root \
29
+ --benchmark_path $benchmark_root \
30
+ --datasets nyuv2_500
code_depth/benchmark/eval/eval_tae.py ADDED
@@ -0,0 +1,295 @@
1
+ import numpy as np
2
+ import cv2
3
+ import matplotlib.pyplot as plt
4
+ import json
5
+ import argparse
6
+ from scipy.ndimage import map_coordinates
7
+ from tqdm import tqdm
8
+ import os
9
+ import gc
10
+ import time
11
+ import torch
12
+
13
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+
15
+ def compute_errors_torch(gt, pred):
16
+ abs_rel = torch.mean(torch.abs(gt - pred) / gt)
17
+ return abs_rel
18
+
19
+ def get_infer(infer_path,args, target_size = None):
20
+ if infer_path.split('.')[-1] == 'npy':
21
+ img_gray = np.load(infer_path)
22
+ img_gray = img_gray.astype(np.float32)
23
+ infer_factor = 1.0
24
+ else:
25
+ img = cv2.imread(infer_path)
26
+ img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
27
+ img_gray = img_gray.astype(np.float32)
28
+ infer_factor = 1.0 / 255.0
29
+
30
+ infer = img_gray / infer_factor
31
+ if args.hard_crop:
32
+ infer = infer[args.a:args.b, args.c:args.d]
33
+
34
+ if target_size is not None:
35
+ if infer.shape[0] != target_size[0] or infer.shape[1] != target_size[1]:
36
+ infer = cv2.resize(infer, (target_size[1], target_size[0]))
37
+ return infer
38
+
39
+ def get_gt(depth_gt_path, gt_factor, args):
40
+ if depth_gt_path.split('.')[-1] == 'npy':
41
+ depth_gt = np.load(depth_gt_path)
42
+ else:
43
+ depth_gt = cv2.imread(depth_gt_path, -1)
44
+ depth_gt = np.array(depth_gt)
45
+ depth_gt = depth_gt / gt_factor
46
+
47
+ depth_gt[depth_gt==0] = 0
48
+ return depth_gt
49
+
50
+ def depth2disparity(depth, return_mask=False):
51
+ if isinstance(depth, np.ndarray):
52
+ disparity = np.zeros_like(depth)
53
+ non_negative_mask = depth > 0
54
+ disparity[non_negative_mask] = 1.0 / depth[non_negative_mask]
55
+ if return_mask:
56
+ return disparity, non_negative_mask
57
+ else:
58
+ return disparity
59
+
60
+ def tae_torch(depth1, depth2, R_2_1, T_2_1, K, mask):
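+ # Back-project depth1 to 3D with intrinsics K, warp the points into frame 2 with the relative
+ # pose (R_2_1, T_2_1), splat the warped depths into an image, and compare them against depth2
+ # with the absolute relative error over pixels that are valid in both views.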
61
+ H, W = depth1.shape
62
+ fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
63
+
64
+ # Generate meshgrid
65
+ xx, yy = torch.meshgrid(torch.arange(W), torch.arange(H))
66
+ xx, yy = xx.t(), yy.t() # Transpose to match the shape (H, W)
67
+
68
+ # Convert meshgrid to tensor
69
+ xx = xx.to(dtype=depth1.dtype, device=depth1.device)
70
+ yy = yy.to(dtype=depth1.dtype, device=depth1.device)
71
+ # Calculate 3D points in frame 1
72
+ X = (xx - cx) * depth1 / fx
73
+ Y = (yy - cy) * depth1 / fy
74
+ Z = depth1
75
+ points3d = torch.stack((X.flatten(), Y.flatten(), Z.flatten()), dim=1) # Shape (H*W, 3)
76
+ T = torch.tensor(T_2_1, dtype=depth1.dtype, device=depth1.device)
77
+
78
+ # Transform 3D points to frame 2
79
+ points3d_transformed = torch.matmul(points3d, R_2_1.T) + T
80
+ X_world, Y_world, Z_world = points3d_transformed[:, 0], points3d_transformed[:, 1], points3d_transformed[:, 2]
81
+ # Project 3D points to 2D plane using intrinsic matrix
82
+ X_plane = (X_world * fx) / Z_world + cx
83
+ Y_plane = (Y_world * fy) / Z_world + cy
84
+
85
+ # Round and convert to integers
86
+ X_plane = torch.round(X_plane).to(dtype=torch.long)
87
+ Y_plane = torch.round(Y_plane).to(dtype=torch.long)
88
+
89
+ # Filter valid indices
90
+ valid_mask = (X_plane >= 0) & (X_plane < W) & (Y_plane >= 0) & (Y_plane < H)
91
+ if valid_mask.sum() == 0:
92
+ return 0
93
+
94
+ depth_proj = torch.zeros((H, W), dtype=depth1.dtype, device=depth1.device)
95
+
96
+ valid_X = X_plane[valid_mask]
97
+ valid_Y = Y_plane[valid_mask]
98
+ valid_Z = Z_world[valid_mask]
99
+
100
+ depth_proj[valid_Y, valid_X] = valid_Z
101
+
102
+ valid_mask = (depth_proj > 0) & (depth2 > 0) & (mask)
103
+ if valid_mask.sum() == 0:
104
+ return 0
105
+ abs_errors = compute_errors_torch(depth2[valid_mask], depth_proj[valid_mask])
106
+
107
+ return abs_errors
108
+
109
+ def eval_TAE(infer_paths, depth_gt_paths, factors, masks, Ks, poses, args):
110
+ gts = []
111
+ infs = []
112
+ dataset_max_depth = args.max_depth_eval
113
+ gt_paths_cur = []
114
+ Ks_cur = []
115
+ poses_cur = []
116
+ masks_cur = []
117
+
118
+ for i in range(len(infer_paths)):
119
+ # DAV missing some frames
120
+ if not os.path.exists(infer_paths[i]):
121
+ continue
122
+
123
+ depth_gt = get_gt(depth_gt_paths[i], factors[i], args)
124
+ depth_gt = depth_gt[args.a:args.b, args.c:args.d]
125
+
126
+ gt_paths_cur.append(depth_gt_paths[i])
127
+ infer = get_infer(infer_paths[i], args, target_size=depth_gt.shape)
128
+
129
+ gts.append(depth_gt)
130
+ infs.append(infer)
131
+ Ks_cur.append(Ks[i])
132
+ poses_cur.append(poses[i])
133
+ if args.mask:
134
+ masks_cur.append(masks[i])
135
+
136
+ gts = np.stack(gts, axis=0)
137
+ infs = np.stack(infs, axis=0)
138
+
139
+ valid_mask = np.logical_and((gts>1e-3), (gts<dataset_max_depth))
140
+
141
+ gt_disp_masked = 1. / (gts[valid_mask].reshape((-1,1)).astype(np.float64) + 1e-8)
142
+ infs = np.clip(infs, a_min=1e-3, a_max=None)
143
+ pred_disp_masked = infs[valid_mask].reshape((-1,1)).astype(np.float64)
144
+
145
+ _ones = np.ones_like(pred_disp_masked)
146
+ A = np.concatenate([pred_disp_masked, _ones], axis=-1)
147
+ X = np.linalg.lstsq(A, gt_disp_masked, rcond=None)[0]
148
+ scale, shift = X
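+ # The least-squares fit above aligns the scale- and shift-invariant predicted disparity to the
+ # ground-truth disparity before the temporal error is computed.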
149
+
150
+ aligned_pred = scale * infs + shift
151
+ aligned_pred = np.clip(aligned_pred, a_min=1e-3, a_max=None)
152
+
153
+ pred_depth = depth2disparity(aligned_pred)
154
+ gt_depth = gts
155
+ pred_depth = np.clip(
156
+ pred_depth, a_min=1e-3, a_max=dataset_max_depth
157
+ )
158
+
159
+ error_sum = 0.
160
+ for i in range(len(gt_paths_cur) -1):
161
+ depth1 = pred_depth[i]
162
+ depth2 = pred_depth[i+1]
163
+
164
+ gt_depth1 = gt_paths_cur[i]
165
+ gt_depth2 = gt_paths_cur[i+1]
166
+ T_1 = poses_cur[i]
167
+ T_2 = poses_cur[i+1]
168
+
169
+ T_2_1 = np.linalg.inv(T_2) @ T_1
170
+
171
+ R_2_1 = T_2_1[:3,:3]
172
+ t_2_1 = T_2_1[:3, 3]
173
+ K = Ks_cur[i]
174
+
175
+ if args.mask:
176
+ mask_path1 = masks_cur[i]
177
+ mask_path2 = masks_cur[i+1]
178
+ mask1 = cv2.imread(mask_path1, -1)
179
+ mask2 = cv2.imread(mask_path2, -1)
180
+ mask1 = mask1[args.a:args.b, args.c:args.d]
181
+ if mask2 is None:
182
+ mask2 = np.ones_like(mask1)
183
+ else:
184
+ mask2 = mask2[args.a:args.b, args.c:args.d]
185
+
186
+ mask1 = mask1 > 0
187
+ mask2 = mask2 > 0
188
+ else:
189
+ mask1 = np.ones_like(depth1)
190
+ mask2 = np.ones_like(depth2)
191
+
192
+ mask1 = mask1 > 0
193
+ mask2 = mask2 > 0
194
+
195
+ depth1 = torch.from_numpy(depth1).to(device=device)
196
+ depth2 = torch.from_numpy(depth2).to(device=device)
197
+ R_2_1 = torch.from_numpy(R_2_1).to(device=device)
198
+ t_2_1 = torch.from_numpy(t_2_1).to(device=device)
199
+ mask1 = torch.from_numpy(mask1).to(device=device)
200
+ mask2 = torch.from_numpy(mask2).to(device=device)
201
+
202
+ error1 = tae_torch(depth1, depth2, R_2_1, t_2_1, K, mask2)
203
+ T_1_2 = np.linalg.inv(T_2_1)
204
+ R_1_2 = T_1_2[:3,:3]
205
+ t_1_2 = T_1_2[:3, 3]
206
+
207
+ R_1_2 = torch.from_numpy(R_1_2).to(device=device)
208
+ t_1_2 = torch.from_numpy(t_1_2).to(device=device)
209
+
210
+ error2 = tae_torch(depth2, depth1, R_1_2, t_1_2, K, mask1)
211
+
212
+ error_sum += error1
213
+ error_sum += error2
214
+
215
+ gc.collect()
216
+ result = error_sum / (2 * (len(gt_paths_cur) -1))
217
+ return result*100
218
+
219
+
220
+ if __name__ == '__main__':
221
+ parser = argparse.ArgumentParser()
222
+ parser.add_argument('--infer_path', type=str, default='')
223
+ parser.add_argument('--benchmark_path', type=str, default='')
224
+
225
+ parser.add_argument('--datasets', type=str, nargs='+', default=['scannet'])  # only scannet is configured below
226
+ parser.add_argument('--start_idx', type=int, default=0)
227
+ parser.add_argument('--end_idx', type=int, default=180)
228
+ parser.add_argument('--eval_scenes_num', type=int, default=20)
229
+ parser.add_argument('--hard_crop', action='store_true', default=False)
230
+
231
+ args = parser.parse_args()
232
+
233
+ results_save_path = os.path.join(args.infer_path, 'results.txt')
234
+
235
+ for dataset in args.datasets:
236
+
237
+ file = open(results_save_path, 'a')
238
+ if dataset == 'scannet':
239
+ args.json_file = os.path.join(args.benchmark_path,'scannet/scannet_video.json')
240
+ args.root_path = os.path.join(args.benchmark_path, 'scannet/')
241
+ args.max_depth_eval = 10.0
242
+ args.min_depth_eval = 0.1
243
+ args.max_eval_len = 200
244
+ args.mask = False
245
+ # DepthCrafter crop
246
+ args.a = 8
247
+ args.b = -8
248
+ args.c = 11
249
+ args.d = -11
250
+
251
+ with open(args.json_file, 'r') as fs:
252
+ path_json = json.load(fs)
253
+
254
+ json_data = path_json[dataset]
255
+ count = 0
256
+ line = '-' * 50
257
+ print(f'<{line} {dataset} start {line}>')
258
+ file.write(f'<{line} {dataset} start {line}>\n')
259
+ results_all = 0.
260
+
261
+ for data in tqdm(json_data[:args.eval_scenes_num]):
262
+ for scene_name in data.keys():
263
+ value = data[scene_name]
264
+ infer_paths = []
265
+ depth_gt_paths = []
266
+ factors = []
267
+ Ks = []
268
+ poses = []
269
+ masks = []
270
+ for images in value:
271
+ infer_path = (args.infer_path + '/'+ dataset + '/' + images['image']).replace('.jpg', '.npy').replace('.png', '.npy')
272
+
273
+ infer_paths.append(infer_path)
274
+ depth_gt_paths.append(args.root_path + '/' + images['gt_depth'])
275
+ factors.append(images['factor'])
276
+ Ks.append(np.array(images['K']))
277
+ poses.append(np.array(images['pose']))
278
+
279
+ if args.mask:
280
+ masks.append(args.root_path + '/' + images['mask'])
281
+
282
+ infer_paths = infer_paths[args.start_idx:args.end_idx]
283
+ depth_gt_paths = depth_gt_paths[args.start_idx:args.end_idx]
284
+ factors = factors[args.start_idx:args.end_idx]
285
+ poses = poses[args.start_idx:args.end_idx]
286
+ Ks = Ks[args.start_idx:args.end_idx]
+ masks = masks[args.start_idx:args.end_idx]
287
+ error = eval_TAE(infer_paths, depth_gt_paths, factors, masks, Ks, poses, args)
288
+ results_all += error
289
+ count += 1
290
+
291
+ print(dataset,': ','tae ', results_all / count)
292
+ file.write(f'{dataset}: {results_all / count}\n')
293
+ file.write(f'<{line} {dataset} finish {line}>\n')
294
+
295
+
code_depth/benchmark/eval/eval_tae.sh ADDED
@@ -0,0 +1,18 @@
1
+ #!/bin/sh
2
+ set -x
3
+ set -e
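+ # Usage: sh benchmark/eval/eval_tae.sh <pred_disp_root> <benchmark_root>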
4
+
5
+ pred_disp_root=$1 # The parent directory containing the [sintel, scannet, KITTI, bonn, NYUv2] predictions
6
+ benchmark_root=$2 # The parent directory containing the [sintel, scannet, KITTI, bonn, NYUv2] ground truth
7
+
8
+ #eval scannet
9
+ python3 benchmark/eval/eval_tae.py \
10
+ --infer_path $pred_disp_root \
11
+ --benchmark_path $benchmark_root \
12
+ --datasets scannet \
13
+ --start_idx 10 \
14
+ --end_idx 180 \
15
+ --eval_scenes_num 20 \
16
+ --hard_crop
17
+
18
+
code_depth/benchmark/eval/metric.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+
3
+ def abs_relative_difference(output, target, valid_mask=None):
4
+ actual_output = output
5
+ actual_target = target
6
+ abs_relative_diff = torch.abs(actual_output - actual_target) / actual_target
7
+ if valid_mask is not None:
8
+ abs_relative_diff[~valid_mask] = 0
9
+ n = valid_mask.sum((-1, -2))
10
+ else:
11
+ n = output.shape[-1] * output.shape[-2]
12
+ abs_relative_diff = torch.sum(abs_relative_diff, (-1, -2)) / n
13
+ return abs_relative_diff.mean()
14
+
15
+ def squared_relative_difference(output, target, valid_mask=None):
16
+ actual_output = output
17
+ actual_target = target
18
+ square_relative_diff = (
19
+ torch.pow(torch.abs(actual_output - actual_target), 2) / actual_target
20
+ )
21
+ if valid_mask is not None:
22
+ square_relative_diff[~valid_mask] = 0
23
+ n = valid_mask.sum((-1, -2))
24
+ else:
25
+ n = output.shape[-1] * output.shape[-2]
26
+ square_relative_diff = torch.sum(square_relative_diff, (-1, -2)) / n
27
+ return square_relative_diff.mean()
28
+
29
+ def rmse_linear(output, target, valid_mask=None):
30
+ actual_output = output
31
+ actual_target = target
32
+ diff = actual_output - actual_target
33
+ if valid_mask is not None:
34
+ diff[~valid_mask] = 0
35
+ n = valid_mask.sum((-1, -2))
36
+ else:
37
+ n = output.shape[-1] * output.shape[-2]
38
+ diff2 = torch.pow(diff, 2)
39
+ mse = torch.sum(diff2, (-1, -2)) / n
40
+ rmse = torch.sqrt(mse)
41
+ return rmse.mean()
42
+
43
+ def rmse_log(output, target, valid_mask=None):
44
+ diff = torch.log(output) - torch.log(target)
45
+ if valid_mask is not None:
46
+ diff[~valid_mask] = 0
47
+ n = valid_mask.sum((-1, -2))
48
+ else:
49
+ n = output.shape[-1] * output.shape[-2]
50
+ diff2 = torch.pow(diff, 2)
51
+ mse = torch.sum(diff2, (-1, -2)) / n # [B]
52
+ rmse = torch.sqrt(mse)
53
+ return rmse.mean()
54
+
55
+ def log10(output, target, valid_mask=None):
56
+ if valid_mask is not None:
57
+ diff = torch.abs(
58
+ torch.log10(output[valid_mask]) - torch.log10(target[valid_mask])
59
+ )
60
+ else:
61
+ diff = torch.abs(torch.log10(output) - torch.log10(target))
62
+ return diff.mean()
63
+
64
+ # adapt from: https://github.com/imran3180/depth-map-prediction/blob/master/main.py
65
+ def threshold_percentage(output, target, threshold_val, valid_mask=None):
66
+ d1 = output / target
67
+ d2 = target / output
68
+ max_d1_d2 = torch.max(d1, d2)
69
+ zero = torch.zeros(*output.shape)
70
+ one = torch.ones(*output.shape)
71
+ bit_mat = torch.where(max_d1_d2.cpu() < threshold_val, one, zero)
72
+ if valid_mask is not None:
73
+ bit_mat[~valid_mask] = 0
74
+ n = valid_mask.sum((-1, -2))
75
+ else:
76
+ n = output.shape[-1] * output.shape[-2]
77
+ count_mat = torch.sum(bit_mat, (-1, -2))
78
+ threshold_mat = count_mat / n.cpu()
79
+ return threshold_mat.mean()
80
+
81
+ def delta1_acc(pred, gt, valid_mask):
82
+ return threshold_percentage(pred, gt, 1.25, valid_mask)
83
+
84
+ def delta2_acc(pred, gt, valid_mask):
85
+ return threshold_percentage(pred, gt, 1.25**2, valid_mask)
86
+
87
+ def delta3_acc(pred, gt, valid_mask):
88
+ return threshold_percentage(pred, gt, 1.25**3, valid_mask)
89
+
90
+ def i_rmse(output, target, valid_mask=None):
91
+ output_inv = 1.0 / output
92
+ target_inv = 1.0 / target
93
+ diff = output_inv - target_inv
94
+ if valid_mask is not None:
95
+ diff[~valid_mask] = 0
96
+ n = valid_mask.sum((-1, -2))
97
+ else:
98
+ n = output.shape[-1] * output.shape[-2]
99
+ diff2 = torch.pow(diff, 2)
100
+ mse = torch.sum(diff2, (-1, -2)) / n # [B]
101
+ rmse = torch.sqrt(mse)
102
+ return rmse.mean()
103
+
104
+ def silog_rmse(depth_pred, depth_gt, valid_mask=None):
105
+ diff = torch.log(depth_pred) - torch.log(depth_gt)
106
+ if valid_mask is not None:
107
+ diff[~valid_mask] = 0
108
+ n = valid_mask.sum((-1, -2))
109
+ else:
110
+ n = depth_gt.shape[-2] * depth_gt.shape[-1]
111
+
112
+ diff2 = torch.pow(diff, 2)
113
+
114
+ first_term = torch.sum(diff2, (-1, -2)) / n
115
+ second_term = torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2)
116
+ loss = torch.sqrt(torch.mean(first_term - second_term)) * 100
117
+ return loss
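+
+ # Usage sketch (assumed shapes): each metric takes (pred, gt, valid_mask) tensors whose last two
+ # dimensions are spatial, e.g. [T, H, W], reduces over (-1, -2), and averages the result, e.g.
+ # err = abs_relative_difference(pred_depth, gt_depth, valid_mask)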
code_depth/benchmark/infer/infer.py ADDED
@@ -0,0 +1,65 @@
1
+ import argparse
2
+ import os
3
+ import cv2
4
+ import json
5
+ import torch
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+
9
+ from video_depth_anything.video_depth import VideoDepthAnything
10
+ from utils.dc_utils import read_video_frames
11
+
12
+ if __name__ == '__main__':
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('--infer_path', type=str, default='')
15
+
16
+ parser.add_argument('--json_file', type=str, default='')
17
+ parser.add_argument('--datasets', type=str, nargs='+', default=['scannet', 'nyuv2'])
18
+
19
+ parser.add_argument('--input_size', type=int, default=518)
20
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
21
+
22
+ args = parser.parse_args()
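+ # Example (placeholder paths):
+ # python3 benchmark/infer/infer.py --infer_path <out_dir> --json_file <benchmark_root>/scannet/scannet_video.json --datasets scannet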
23
+
24
+ for dataset in args.datasets:
25
+
26
+ with open(args.json_file, 'r') as fs:
27
+ path_json = json.load(fs)
28
+
29
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+ model_configs = {
32
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
33
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
34
+ }
35
+
36
+ video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
37
+ video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
38
+ video_depth_anything = video_depth_anything.to(DEVICE).eval()
39
+
40
+ json_data = path_json[dataset]
41
+ root_path = os.path.dirname(args.json_file)
42
+ for data in tqdm(json_data):
43
+ for key in data.keys():
44
+ value = data[key]
45
+ infer_paths = []
46
+
47
+ videos = []
48
+ for images in value:
49
+
50
+ image_path = os.path.join(root_path, images['image'])
51
+ infer_path = (args.infer_path + '/'+ dataset + '/' + images['image']).replace('.jpg', '.npy').replace('.png', '.npy')
52
+ infer_paths.append(infer_path)
53
+
54
+ img = cv2.imread(image_path)
55
+ videos.append(img)
56
+ videos = np.stack(videos, axis=0)
57
+ target_fps=1
58
+ depths, fps = video_depth_anything.infer_video_depth(videos, target_fps, input_size=args.input_size, device=DEVICE, fp32=True)
59
+
60
+ for i in range(len(infer_paths)):
61
+ infer_path = infer_paths[i]
62
+ os.makedirs(os.path.dirname(infer_path), exist_ok=True)
63
+ depth = depths[i]
64
+ np.save(infer_path, depth)
65
+
code_depth/get_weights.sh ADDED
@@ -0,0 +1,6 @@
1
+ #!/bin/bash
2
+
3
+ mkdir -p checkpoints
4
+ cd checkpoints
5
+ wget https://huggingface.co/depth-anything/Video-Depth-Anything-Small/resolve/main/video_depth_anything_vits.pth
6
+ wget https://huggingface.co/depth-anything/Video-Depth-Anything-Large/resolve/main/video_depth_anything_vitl.pth
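+ # The weights land in ./checkpoints/, which run.py and the benchmark scripts load from by default.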
code_depth/large_files.txt ADDED
@@ -0,0 +1,2 @@
1
+ ./checkpoints/video_depth_anything_vitl.pth
2
+ ./checkpoints/video_depth_anything_vits.pth
code_depth/requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ numpy==1.23.1
2
+ torch==2.1.1
3
+ torchvision==0.16.1
4
+ opencv-python
5
+ matplotlib
6
+ pillow
7
+ imageio==2.19.3
8
+ imageio-ffmpeg==0.4.7
9
+ decord
10
+ xformers==0.0.23
11
+ einops==0.4.1
12
+ easydict
13
+ tqdm
14
+ OpenEXR==3.3.1
code_depth/run.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import numpy as np
16
+ import os
17
+ import torch
18
+
19
+ from video_depth_anything.video_depth import VideoDepthAnything
20
+ from utils.dc_utils import read_video_frames, save_video
21
+
22
+ if __name__ == '__main__':
23
+ parser = argparse.ArgumentParser(description='Video Depth Anything')
24
+ parser.add_argument('--input_video', type=str, default='./assets/example_videos/davis_rollercoaster.mp4')
25
+ parser.add_argument('--output_dir', type=str, default='./outputs')
26
+ parser.add_argument('--input_size', type=int, default=518)
27
+ parser.add_argument('--max_res', type=int, default=1280)
28
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
29
+ parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
30
+ parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')
31
+ parser.add_argument('--fp32', action='store_true', help='model infer with torch.float32, default is torch.float16')
32
+ parser.add_argument('--grayscale', action='store_true', help='do not apply colorful palette')
33
+ parser.add_argument('--save_npz', action='store_true', help='save depths as npz')
34
+ parser.add_argument('--save_exr', action='store_true', help='save depths as exr')
35
+
36
+ args = parser.parse_args()
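+ # Example invocation (paths are placeholders):
+ # python3 run.py --input_video ./assets/example_videos/davis_rollercoaster.mp4 --output_dir ./outputs --encoder vitl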
37
+
38
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
39
+
40
+ model_configs = {
41
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
42
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
43
+ }
44
+
45
+ video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
46
+ video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
47
+ video_depth_anything = video_depth_anything.to(DEVICE).eval()
48
+
49
+ frames, target_fps = read_video_frames(args.input_video, args.max_len, args.target_fps, args.max_res)
50
+ depths, fps = video_depth_anything.infer_video_depth(frames, target_fps, input_size=args.input_size, device=DEVICE, fp32=args.fp32)
51
+
52
+ video_name = os.path.basename(args.input_video)
53
+ if not os.path.exists(args.output_dir):
54
+ os.makedirs(args.output_dir)
55
+
56
+ processed_video_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_src.mp4')
57
+ depth_vis_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_vis.mp4')
58
+ save_video(frames, processed_video_path, fps=fps)
59
+ save_video(depths, depth_vis_path, fps=fps, is_depths=True, grayscale=args.grayscale)
60
+
61
+ if args.save_npz:
62
+ depth_npz_path = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_depths.npz')
63
+ np.savez_compressed(depth_npz_path, depths=depths)
64
+ if args.save_exr:
65
+ depth_exr_dir = os.path.join(args.output_dir, os.path.splitext(video_name)[0]+'_depths_exr')
66
+ os.makedirs(depth_exr_dir, exist_ok=True)
67
+ import OpenEXR
68
+ import Imath
69
+ for i, depth in enumerate(depths):
70
+ output_exr = f"{depth_exr_dir}/frame_{i:05d}.exr"
71
+ header = OpenEXR.Header(depth.shape[1], depth.shape[0])
72
+ header["channels"] = {
73
+ "Z": Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
74
+ }
75
+ exr_file = OpenEXR.OutputFile(output_exr, header)
76
+ exr_file.writePixels({"Z": depth.tobytes()})
77
+ exr_file.close()
78
+
79
+
80
+
81
+
code_depth/run_images_rord.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import numpy as np
16
+ import os
17
+ import torch
18
+ import os
19
+ import cv2
20
+ import numpy as np
21
+ import matplotlib.pyplot as plt
22
+ import matplotlib.cm as cm
23
+ from PIL import Image
24
+ from video_depth_anything.video_depth import VideoDepthAnything
25
+ from utils.dc_utils import read_video_frames, save_video
26
+ import tqdm
27
+
28
+ if __name__ == '__main__':
29
+ parser = argparse.ArgumentParser(description='Video Depth Anything')
30
+ parser.add_argument('--input_size', type=int, default=518)
31
+ parser.add_argument('--max_res', type=int, default=1280)
32
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
33
+ parser.add_argument('--max_len', type=int, default=-1, help='maximum length of the input video, -1 means no limit')
34
+ parser.add_argument('--target_fps', type=int, default=-1, help='target fps of the input video, -1 means the original fps')
35
+ parser.add_argument('--fp32', action='store_true', help='model infer with torch.float32, default is torch.float16')
36
+ parser.add_argument('--grayscale', action='store_true', help='do not apply colorful palette')
37
+ parser.add_argument('--save_npz', action='store_true', help='save depths as npz')
38
+ parser.add_argument('--save_exr', action='store_true', help='save depths as exr')
39
+
40
+ args = parser.parse_args()
41
+
42
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
43
+
44
+ model_configs = {
45
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
46
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
47
+ }
48
+
49
+ video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
50
+ video_depth_anything.load_state_dict(torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'), strict=True)
51
+ video_depth_anything = video_depth_anything.to(DEVICE).eval()
52
+
53
+ # place input dir and out dir here
54
+ root_img_dir = "RORD/train/img"
55
+ root_gt_dir = "RORD/train/gt"
56
+ save_root_img_base = "RORD/val/img_depth"
57
+ save_root_gt_base = "RORD/val/gt_depth"
58
+
59
+ video_ids = sorted(os.listdir(root_img_dir))
60
+
61
+ for video_id in tqdm.tqdm(video_ids):
62
+ frame_dir = os.path.join(root_img_dir, video_id)
63
+
64
+ frame_paths = sorted([
65
+ os.path.join(frame_dir, fname) for fname in os.listdir(frame_dir)
66
+ if fname.endswith(".jpg") or fname.endswith(".png")
67
+ ])
68
+ frames = [cv2.imread(p)[:, :, ::-1] for p in frame_paths]
69
+ gt_path = frame_paths[0].replace("/img/", "/gt/")
70
+
71
+ gt_img = cv2.imread(gt_path)[:, :, ::-1] # BGR to RGB
72
+ frames.append(gt_img)
73
+
74
+ resized_frames = []
75
+ max_res = 1280
76
+ for f in frames:
77
+ h, w = f.shape[:2]
78
+ if max(h, w) > max_res:
79
+ scale = max_res / max(h, w)
80
+ f = cv2.resize(f, (int(w * scale), int(h * scale)))
81
+ resized_frames.append(f)
82
+
83
+ resized_frames = np.stack(resized_frames, axis=0)
84
+
85
+ depths, _ = video_depth_anything.infer_video_depth(
86
+ resized_frames, 32, input_size=518, device=DEVICE, fp32=False
87
+ )
88
+
89
+ save_root_img = os.path.join(save_root_img_base, video_id)
90
+ save_root_gt = os.path.join(save_root_gt_base, video_id)
91
+ os.makedirs(save_root_img, exist_ok=True)
92
+ os.makedirs(save_root_gt, exist_ok=True)
93
+
94
+ colormap = np.array(cm.get_cmap("inferno").colors)
95
+ d_min, d_max = depths.min(), depths.max()
96
+ for i, path in enumerate(frame_paths):
97
+ fname = os.path.basename(path)
98
+
99
+ depth = depths[i]
100
+ depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
101
+ depth_vis = (colormap[depth_norm] * 255).astype(np.uint8) # shape: (H, W, 3), uint8
102
+
103
+ img_path = os.path.join(save_root_img, fname)
104
+ Image.fromarray(depth_vis).save(img_path)
105
+
106
+ gt_depth = depths[-1]
107
+ gt_norm = ((gt_depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
108
+ gt_vis = (colormap[gt_norm] * 255).astype(np.uint8)
109
+
110
+ gt_save_path = os.path.join(save_root_gt, fname)
111
+ Image.fromarray(gt_vis).save(gt_save_path)
112
+
code_depth/run_single_image.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # http://www.apache.org/licenses/LICENSE-2.0
4
+
5
+ import os
6
+ import numpy as np
7
+ import torch
8
+ import cv2
9
+ import matplotlib.cm as cm
10
+ from PIL import Image
11
+ from video_depth_anything.video_depth import VideoDepthAnything
12
+
13
+ if __name__ == '__main__':
14
+
15
+ import argparse
16
+ parser = argparse.ArgumentParser(description='Video Depth Anything')
17
+ parser.add_argument('--input_size', type=int, default=518)
18
+ parser.add_argument('--max_res', type=int, default=1280)
19
+ parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitl'])
20
+ parser.add_argument('--max_len', type=int, default=-1)
21
+ parser.add_argument('--target_fps', type=int, default=-1)
22
+ parser.add_argument('--fp32', action='store_true')
23
+ parser.add_argument('--grayscale', action='store_true')
24
+ parser.add_argument('--save_npz', action='store_true')
25
+ parser.add_argument('--save_exr', action='store_true')
26
+ args = parser.parse_args()
27
+
28
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
29
+
30
+ model_configs = {
31
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
32
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
33
+ }
34
+
35
+ video_depth_anything = VideoDepthAnything(**model_configs[args.encoder])
36
+ video_depth_anything.load_state_dict(
37
+ torch.load(f'./checkpoints/video_depth_anything_{args.encoder}.pth', map_location='cpu'),
38
+ strict=True
39
+ )
40
+ video_depth_anything = video_depth_anything.to(DEVICE).eval()
41
+
42
+ # your image input and output path
43
+ input_path = ""
44
+ output_path = ""
45
+
46
+
47
+ img = cv2.imread(input_path)[:, :, ::-1]
48
+ h, w = img.shape[:2]
49
+
50
+ if max(h, w) > args.max_res:
51
+ scale = args.max_res / max(h, w)
52
+ img = cv2.resize(img, (int(w * scale), int(h * scale)))
53
+
54
+ frame_tensor = np.stack([img], axis=0)
55
+
56
+
57
+ depths, _ = video_depth_anything.infer_video_depth(
58
+ frame_tensor, 32, input_size=518, device=DEVICE, fp32=False
59
+ )
60
+ depth = depths[0]
61
+
62
+
63
+ colormap = np.array(cm.get_cmap("inferno").colors)
64
+ d_min, d_max = depth.min(), depth.max()
65
+ depth_norm = ((depth - d_min) / (d_max - d_min + 1e-6) * 255).astype(np.uint8)
66
+ depth_vis = (colormap[depth_norm] * 255).astype(np.uint8)
67
+
68
+ Image.fromarray(depth_vis).save(output_path)
69
+ print(f"Saved depth map to: {output_path}")
code_depth/utils/dc_utils.py ADDED
@@ -0,0 +1,86 @@
1
+ # This file is originally from DepthCrafter/depthcrafter/utils.py at main · Tencent/DepthCrafter
2
+ # SPDX-License-Identifier: MIT
3
+ #
4
+ # This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
5
+ # Original file is released under the MIT License, with the full license text available at https://github.com/Tencent/DepthCrafter?tab=License-1-ov-file.
6
+ import numpy as np
7
+ import matplotlib.cm as cm
8
+ import imageio
9
+ try:
10
+ from decord import VideoReader, cpu
11
+ DECORD_AVAILABLE = True
12
+ except:
13
+ import cv2
14
+ DECORD_AVAILABLE = False
15
+
16
+ def ensure_even(value):
17
+ return value if value % 2 == 0 else value + 1
18
+
19
+ def read_video_frames(video_path, process_length, target_fps=-1, max_res=-1):
20
+ if DECORD_AVAILABLE:
21
+ vid = VideoReader(video_path, ctx=cpu(0))
22
+ original_height, original_width = vid.get_batch([0]).shape[1:3]
23
+ height = original_height
24
+ width = original_width
25
+ if max_res > 0 and max(height, width) > max_res:
26
+ scale = max_res / max(original_height, original_width)
27
+ height = ensure_even(round(original_height * scale))
28
+ width = ensure_even(round(original_width * scale))
29
+
30
+ vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
31
+
32
+ fps = vid.get_avg_fps() if target_fps == -1 else target_fps
33
+ stride = round(vid.get_avg_fps() / fps)
34
+ stride = max(stride, 1)
35
+ frames_idx = list(range(0, len(vid), stride))
36
+ if process_length != -1 and process_length < len(frames_idx):
37
+ frames_idx = frames_idx[:process_length]
38
+ frames = vid.get_batch(frames_idx).asnumpy()
39
+ else:
40
+ cap = cv2.VideoCapture(video_path)
41
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
42
+ original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
43
+ original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
44
+
45
+ if max_res > 0 and max(original_height, original_width) > max_res:
46
+ scale = max_res / max(original_height, original_width)
47
+ height = round(original_height * scale)
48
+ width = round(original_width * scale)
49
+
50
+ fps = original_fps if target_fps < 0 else target_fps
51
+
52
+ stride = max(round(original_fps / fps), 1)
53
+
54
+ frames = []
55
+ frame_count = 0
56
+ while cap.isOpened():
57
+ ret, frame = cap.read()
58
+ if not ret or (process_length > 0 and frame_count >= process_length):
59
+ break
60
+ if frame_count % stride == 0:
61
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
62
+ if max_res > 0 and max(original_height, original_width) > max_res:
63
+ frame = cv2.resize(frame, (width, height)) # Resize frame
64
+ frames.append(frame)
65
+ frame_count += 1
66
+ cap.release()
67
+ frames = np.stack(frames, axis=0)
68
+
69
+ return frames, fps
70
+
71
+
72
+ def save_video(frames, output_video_path, fps=10, is_depths=False, grayscale=False):
73
+ writer = imageio.get_writer(output_video_path, fps=fps, macro_block_size=1, codec='libx264', ffmpeg_params=['-crf', '18'])
74
+ if is_depths:
75
+ colormap = np.array(cm.get_cmap("inferno").colors)
76
+ d_min, d_max = frames.min(), frames.max()
77
+ for i in range(frames.shape[0]):
78
+ depth = frames[i]
79
+ depth_norm = ((depth - d_min) / (d_max - d_min) * 255).astype(np.uint8)
80
+ depth_vis = (colormap[depth_norm] * 255).astype(np.uint8) if not grayscale else depth_norm
81
+ writer.append_data(depth_vis)
82
+ else:
83
+ for i in range(frames.shape[0]):
84
+ writer.append_data(frames[i])
85
+
86
+ writer.close()
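+
+ # Example (assumed paths): frames, fps = read_video_frames("input.mp4", process_length=-1, target_fps=-1, max_res=1280)
+ # and then save_video(depth_frames, "depth_vis.mp4", fps=fps, is_depths=True).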
code_depth/utils/util.py ADDED
@@ -0,0 +1,74 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+
16
+ def compute_scale_and_shift(prediction, target, mask, scale_only=False):
17
+ if scale_only:
18
+ return compute_scale(prediction, target, mask), 0
19
+ else:
20
+ return compute_scale_and_shift_full(prediction, target, mask)
21
+
22
+
23
+ def compute_scale(prediction, target, mask):
24
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
25
+ prediction = prediction.astype(np.float32)
26
+ target = target.astype(np.float32)
27
+ mask = mask.astype(np.float32)
28
+
29
+ a_00 = np.sum(mask * prediction * prediction)
30
+ a_01 = np.sum(mask * prediction)
31
+ a_11 = np.sum(mask)
32
+
33
+ # right hand side: b = [b_0, b_1]
34
+ b_0 = np.sum(mask * prediction * target)
35
+
36
+ x_0 = b_0 / (a_00 + 1e-6)
37
+
38
+ return x_0
39
+
40
+ def compute_scale_and_shift_full(prediction, target, mask):
41
+ # system matrix: A = [[a_00, a_01], [a_10, a_11]]
42
+ prediction = prediction.astype(np.float32)
43
+ target = target.astype(np.float32)
44
+ mask = mask.astype(np.float32)
45
+
46
+ a_00 = np.sum(mask * prediction * prediction)
47
+ a_01 = np.sum(mask * prediction)
48
+ a_11 = np.sum(mask)
49
+
50
+ b_0 = np.sum(mask * prediction * target)
51
+ b_1 = np.sum(mask * target)
52
+
53
+ x_0 = 1
54
+ x_1 = 0
55
+
56
+ det = a_00 * a_11 - a_01 * a_01
57
+
58
+ if det != 0:
59
+ x_0 = (a_11 * b_0 - a_01 * b_1) / det
60
+ x_1 = (-a_01 * b_0 + a_00 * b_1) / det
61
+
62
+ return x_0, x_1
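+ # Closed-form solution of the 2x2 normal equations for min over (scale, shift) of sum(mask * (scale * prediction + shift - target)^2);
+ # falls back to scale=1, shift=0 when the system is singular (det == 0).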
63
+
64
+
65
+ def get_interpolate_frames(frame_list_pre, frame_list_post):
66
+ assert len(frame_list_pre) == len(frame_list_post)
67
+ min_w = 0.0
68
+ max_w = 1.0
69
+ step = (max_w - min_w) / (len(frame_list_pre)-1)
70
+ post_w_list = [min_w] + [i * step for i in range(1,len(frame_list_pre)-1)] + [max_w]
71
+ interpolated_frames = []
72
+ for i in range(len(frame_list_pre)):
73
+ interpolated_frames.append(frame_list_pre[i] * (1-post_w_list[i]) + frame_list_post[i] * post_w_list[i])
74
+ return interpolated_frames
code_depth/video_depth_anything/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+         output, i, total_block_len = [], 0, len(self.blocks[-1])
+         # If n is an int, take the n last blocks. If it's a list, take them
+         blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+         for block_chunk in self.blocks:
+             for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                 x = blk(x)
+                 if i in blocks_to_take:
+                     output.append(x)
+                 i += 1
+         assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+         return output
+
+     def get_intermediate_layers(
+         self,
+         x: torch.Tensor,
+         n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+         reshape: bool = False,
+         return_class_token: bool = False,
+         norm=True
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+         if self.chunked_blocks:
+             outputs = self._get_intermediate_layers_chunked(x, n)
+         else:
+             outputs = self._get_intermediate_layers_not_chunked(x, n)
+         if norm:
+             outputs = [self.norm(out) for out in outputs]
+         class_tokens = [out[:, 0] for out in outputs]
+         outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+         if reshape:
+             B, _, w, h = x.shape
+             outputs = [
+                 out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                 for out in outputs
+             ]
+         if return_class_token:
+             return tuple(zip(outputs, class_tokens))
+         return tuple(outputs)
+
+     def forward(self, *args, is_training=False, **kwargs):
+         ret = self.forward_features(*args, **kwargs)
+         if is_training:
+             return ret
+         else:
+             return self.head(ret["x_norm_clstoken"])
+
+
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
+     """ViT weight initialization, original timm impl (for reproducibility)"""
+     if isinstance(module, nn.Linear):
+         trunc_normal_(module.weight, std=0.02)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+
+
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=384,
+         depth=12,
+         num_heads=6,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+     """
+     Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+     """
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=1536,
+         depth=40,
+         num_heads=24,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def DINOv2(model_name):
+     model_zoo = {
+         "vits": vit_small,
+         "vitb": vit_base,
+         "vitl": vit_large,
+         "vitg": vit_giant2
+     }
+
+     return model_zoo[model_name](
+         img_size=518,
+         patch_size=14,
+         init_values=1.0,
+         ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+         block_chunks=0,
+         num_register_tokens=0,
+         interpolate_antialias=False,
+         interpolate_offset=0.1
+     )
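A minimal usage sketch for the factory above, assuming the import path matches the directory layout and that pretrained weights are loaded separately (the encoder here runs with random initialization):

import torch
from video_depth_anything.dinov2 import DINOv2  # import path assumed from the file layout

encoder = DINOv2("vitl")                    # ViT-L/14, img_size=518, no register tokens
x = torch.randn(1, 3, 518, 518)             # H and W must be multiples of patch_size=14
features = encoder.get_intermediate_layers(x, n=4, reshape=True, return_class_token=True)
# -> tuple of 4 (feature_map, class_token) pairs, each map shaped (1, 1024, 37, 37)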
code_depth/video_depth_anything/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .mlp import Mlp
+ from .patch_embed import PatchEmbed
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+ from .block import NestedTensorBlock
+ from .attention import MemEffAttention
code_depth/video_depth_anything/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ import logging
+
+ from torch import Tensor
+ from torch import nn
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ try:
+     from xformers.ops import memory_efficient_attention, unbind, fmha
+
+     XFORMERS_AVAILABLE = True
+ except ImportError:
+     logger.warning("xFormers not available")
+     XFORMERS_AVAILABLE = False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+         attn = q @ k.transpose(-2, -1)
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class MemEffAttention(Attention):
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         if not XFORMERS_AVAILABLE:
+             assert attn_bias is None, "xFormers is required for nested tensors usage"
+             return super().forward(x)
+
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+         q, k, v = unbind(qkv, 2)
+
+         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+         x = x.reshape([B, N, C])
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
code_depth/video_depth_anything/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
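A small sketch of the plain-tensor path through Block, with assumed import paths and illustrative sizes; init_values is left at its None default, so the LayerScale branches collapse to nn.Identity, and MemEffAttention falls back to standard attention when xFormers is not installed.

import torch
from video_depth_anything.dinov2_layers.block import Block              # path assumed
from video_depth_anything.dinov2_layers.attention import MemEffAttention

blk = Block(dim=384, num_heads=6, mlp_ratio=4, attn_class=MemEffAttention)
tokens = torch.randn(2, 197, 384)   # (batch, sequence, dim), illustrative sizes
out = blk(tokens)                   # pre-norm attention + MLP, each with a residual connection
assert out.shape == tokens.shape    # with xFormers installed the GPU memory-efficient kernel is used instead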
code_depth/video_depth_anything/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+ from torch import nn
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
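A quick behavioural check of the stochastic depth above; the drop probability and tensor sizes are illustrative and the import path is assumed from the file layout.

import torch
from video_depth_anything.dinov2_layers.drop_path import DropPath  # path assumed

dp = DropPath(drop_prob=0.25)
x = torch.ones(8, 16, 4)
dp.train()
y = dp(x)                      # each sample is either zeroed or rescaled by 1 / (1 - 0.25)
dp.eval()
assert torch.equal(dp(x), x)   # identity at inference time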
code_depth/video_depth_anything/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+ from torch import nn
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
code_depth/video_depth_anything/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
code_depth/video_depth_anything/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ from typing import Callable, Optional, Tuple, Union
+
+ from torch import Tensor
+ import torch.nn as nn
+
+
+ def make_2tuple(x):
+     if isinstance(x, tuple):
+         assert len(x) == 2
+         return x
+
+     assert isinstance(x, int)
+     return (x, x)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+     Args:
+         img_size: Image size.
+         patch_size: Patch token size.
+         in_chans: Number of input image channels.
+         embed_dim: Number of linear projection output channels.
+         norm_layer: Normalization layer.
+     """
+
+     def __init__(
+         self,
+         img_size: Union[int, Tuple[int, int]] = 224,
+         patch_size: Union[int, Tuple[int, int]] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten_embedding: bool = True,
+     ) -> None:
+         super().__init__()
+
+         image_HW = make_2tuple(img_size)
+         patch_HW = make_2tuple(patch_size)
+         patch_grid_size = (
+             image_HW[0] // patch_HW[0],
+             image_HW[1] // patch_HW[1],
+         )
+
+         self.img_size = image_HW
+         self.patch_size = patch_HW
+         self.patches_resolution = patch_grid_size
+         self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.flatten_embedding = flatten_embedding
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         patch_H, patch_W = self.patch_size
+
+         assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+         x = self.proj(x)  # B C H W
+         H, W = x.size(2), x.size(3)
+         x = x.flatten(2).transpose(1, 2)  # B HW C
+         x = self.norm(x)
+         if not self.flatten_embedding:
+             x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+         return x
+
+     def flops(self) -> float:
+         Ho, Wo = self.patches_resolution
+         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
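A shape check for the patch embedding above with the 518 / 14 configuration used by the DINOv2 factory in this commit; the embed_dim value and import path are assumptions.

import torch
from video_depth_anything.dinov2_layers.patch_embed import PatchEmbed  # path assumed

embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
tokens = embed(torch.randn(1, 3, 518, 518))
print(tokens.shape)   # torch.Size([1, 1369, 1024]) -- a 37 x 37 grid of patch tokens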
code_depth/video_depth_anything/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+ import torch.nn.functional as F
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+         self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x12 = self.w12(x)
+         x1, x2 = x12.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         return self.w3(hidden)
+
+
+ try:
+     from xformers.ops import SwiGLU
+
+     XFORMERS_AVAILABLE = True
+ except ImportError:
+     SwiGLU = SwiGLUFFN
+     XFORMERS_AVAILABLE = False
+
+
+ class SwiGLUFFNFused(SwiGLU):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+         super().__init__(
+             in_features=in_features,
+             hidden_features=hidden_features,
+             out_features=out_features,
+             bias=bias,
+         )
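The hidden width in SwiGLUFFNFused is shrunk to roughly two thirds and then padded up to a multiple of 8; a worked example with an illustrative width:

hidden_features = 4096                   # illustrative value
scaled = int(hidden_features * 2 / 3)    # 2730
rounded = (scaled + 7) // 8 * 8          # 2736, the next multiple of 8
assert rounded == 2736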
code_depth/video_depth_anything/dpt.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from .util.blocks import FeatureFusionBlock, _make_scratch
19
+
20
+
21
+ def _make_fusion_block(features, use_bn, size=None):
22
+ return FeatureFusionBlock(
23
+ features,
24
+ nn.ReLU(False),
25
+ deconv=False,
26
+ bn=use_bn,
27
+ expand=False,
28
+ align_corners=True,
29
+ size=size,
30
+ )
31
+
32
+
33
+ class ConvBlock(nn.Module):
34
+ def __init__(self, in_feature, out_feature):
35
+ super().__init__()
36
+
37
+ self.conv_block = nn.Sequential(
38
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
39
+ nn.BatchNorm2d(out_feature),
40
+ nn.ReLU(True)
41
+ )
42
+
43
+ def forward(self, x):
44
+ return self.conv_block(x)
45
+
46
+
47
+ class DPTHead(nn.Module):
48
+ def __init__(
49
+ self,
50
+ in_channels,
51
+ features=256,
52
+ use_bn=False,
53
+ out_channels=[256, 512, 1024, 1024],
54
+ use_clstoken=False
55
+ ):
56
+ super(DPTHead, self).__init__()
57
+
58
+ self.use_clstoken = use_clstoken
59
+
60
+ self.projects = nn.ModuleList([
61
+ nn.Conv2d(
62
+ in_channels=in_channels,
63
+ out_channels=out_channel,
64
+ kernel_size=1,
65
+ stride=1,
66
+ padding=0,
67
+ ) for out_channel in out_channels
68
+ ])
69
+
70
+ self.resize_layers = nn.ModuleList([
71
+ nn.ConvTranspose2d(
72
+ in_channels=out_channels[0],
73
+ out_channels=out_channels[0],
74
+ kernel_size=4,
75
+ stride=4,
76
+ padding=0),
77
+ nn.ConvTranspose2d(
78
+ in_channels=out_channels[1],
79
+ out_channels=out_channels[1],
80
+ kernel_size=2,
81
+ stride=2,
82
+ padding=0),
83
+ nn.Identity(),
84
+ nn.Conv2d(
85
+ in_channels=out_channels[3],
86
+ out_channels=out_channels[3],
87
+ kernel_size=3,
88
+ stride=2,
89
+ padding=1)
90
+ ])
91
+
92
+ if use_clstoken:
93
+ self.readout_projects = nn.ModuleList()
94
+ for _ in range(len(self.projects)):
95
+ self.readout_projects.append(
96
+ nn.Sequential(
97
+ nn.Linear(2 * in_channels, in_channels),
98
+ nn.GELU()))
99
+
100
+ self.scratch = _make_scratch(
101
+ out_channels,
102
+ features,
103
+ groups=1,
104
+ expand=False,
105
+ )
106
+
107
+ self.scratch.stem_transpose = None
108
+
109
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
110
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
111
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
112
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
113
+
114
+ head_features_1 = features
115
+ head_features_2 = 32
116
+
117
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
118
+ self.scratch.output_conv2 = nn.Sequential(
119
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
120
+ nn.ReLU(True),
121
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
122
+ nn.ReLU(True),
123
+ nn.Identity(),
124
+ )
125
+
126
+ def forward(self, out_features, patch_h, patch_w):
127
+ out = []
128
+ for i, x in enumerate(out_features):
129
+ if self.use_clstoken:
130
+ x, cls_token = x[0], x[1]
131
+ readout = cls_token.unsqueeze(1).expand_as(x)
132
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
133
+ else:
134
+ x = x[0]
135
+
136
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
137
+
138
+ x = self.projects[i](x)
139
+ x = self.resize_layers[i](x)
140
+
141
+ out.append(x)
142
+
143
+ layer_1, layer_2, layer_3, layer_4 = out
144
+
145
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
146
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
147
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
148
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
149
+
150
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
151
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
152
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
153
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
154
+
155
+ out = self.scratch.output_conv1(path_1)
156
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
157
+ out = self.scratch.output_conv2(out)
158
+
159
+ return out
160
+
code_depth/video_depth_anything/dpt_temporal.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.nn as nn
17
+ from .dpt import DPTHead
18
+ from .motion_module.motion_module import TemporalModule
19
+ from easydict import EasyDict
20
+
21
+
22
+ class DPTHeadTemporal(DPTHead):
23
+ def __init__(self,
24
+ in_channels,
25
+ features=256,
26
+ use_bn=False,
27
+ out_channels=[256, 512, 1024, 1024],
28
+ use_clstoken=False,
29
+ num_frames=32,
30
+ pe='ape'
31
+ ):
32
+ super().__init__(in_channels, features, use_bn, out_channels, use_clstoken)
33
+
34
+ assert num_frames > 0
35
+ motion_module_kwargs = EasyDict(num_attention_heads = 8,
36
+ num_transformer_block = 1,
37
+ num_attention_blocks = 2,
38
+ temporal_max_len = num_frames,
39
+ zero_initialize = True,
40
+ pos_embedding_type = pe)
41
+
42
+ self.motion_modules = nn.ModuleList([
43
+ TemporalModule(in_channels=out_channels[2],
44
+ **motion_module_kwargs),
45
+ TemporalModule(in_channels=out_channels[3],
46
+ **motion_module_kwargs),
47
+ TemporalModule(in_channels=features,
48
+ **motion_module_kwargs),
49
+ TemporalModule(in_channels=features,
50
+ **motion_module_kwargs)
51
+ ])
52
+
53
+ def forward(self, out_features, patch_h, patch_w, frame_length, micro_batch_size=4):
54
+ out = []
55
+ for i, x in enumerate(out_features):
56
+ if self.use_clstoken:
57
+ x, cls_token = x[0], x[1]
58
+ readout = cls_token.unsqueeze(1).expand_as(x)
59
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
60
+ else:
61
+ x = x[0]
62
+
63
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)).contiguous()
64
+
65
+ B, T = x.shape[0] // frame_length, frame_length
66
+ x = self.projects[i](x)
67
+ x = self.resize_layers[i](x)
68
+
69
+ out.append(x)
70
+
71
+ layer_1, layer_2, layer_3, layer_4 = out
72
+
73
+ B, T = layer_1.shape[0] // frame_length, frame_length
74
+
75
+ layer_3 = self.motion_modules[0](layer_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
76
+ layer_4 = self.motion_modules[1](layer_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
77
+
78
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
79
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
80
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
81
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
82
+
83
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
84
+ path_4 = self.motion_modules[2](path_4.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
85
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
86
+ path_3 = self.motion_modules[3](path_3.unflatten(0, (B, T)).permute(0, 2, 1, 3, 4), None, None).permute(0, 2, 1, 3, 4).flatten(0, 1)
87
+
88
+ batch_size = layer_1_rn.shape[0]
89
+ if batch_size <= micro_batch_size or batch_size % micro_batch_size != 0:
90
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
91
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
92
+
93
+ out = self.scratch.output_conv1(path_1)
94
+ out = F.interpolate(
95
+ out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
96
+ )
97
+ ori_type = out.dtype
98
+ with torch.autocast(device_type="cuda", enabled=False):
99
+ out = self.scratch.output_conv2(out.float())
100
+ return out.to(ori_type)
101
+ else:
102
+ ret = []
103
+ for i in range(0, batch_size, micro_batch_size):
104
+ path_2 = self.scratch.refinenet2(path_3[i:i + micro_batch_size], layer_2_rn[i:i + micro_batch_size], size=layer_1_rn[i:i + micro_batch_size].shape[2:])
105
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn[i:i + micro_batch_size])
106
+ out = self.scratch.output_conv1(path_1)
107
+ out = F.interpolate(
108
+ out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True
109
+ )
110
+ ori_type = out.dtype
111
+ with torch.autocast(device_type="cuda", enabled=False):
112
+ out = self.scratch.output_conv2(out.float())
113
+ ret.append(out.to(ori_type))
114
+ return torch.cat(ret, dim=0)
code_depth/video_depth_anything/motion_module/attention.py ADDED
@@ -0,0 +1,429 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ try:
21
+ import xformers
22
+ import xformers.ops
23
+
24
+ XFORMERS_AVAILABLE = True
25
+ except ImportError:
26
+ print("xFormers not available")
27
+ XFORMERS_AVAILABLE = False
28
+
29
+
30
+ class CrossAttention(nn.Module):
31
+ r"""
32
+ A cross attention layer.
33
+
34
+ Parameters:
35
+ query_dim (`int`): The number of channels in the query.
36
+ cross_attention_dim (`int`, *optional*):
37
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
38
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
39
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
40
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
41
+ bias (`bool`, *optional*, defaults to False):
42
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ query_dim: int,
48
+ cross_attention_dim: Optional[int] = None,
49
+ heads: int = 8,
50
+ dim_head: int = 64,
51
+ dropout: float = 0.0,
52
+ bias=False,
53
+ upcast_attention: bool = False,
54
+ upcast_softmax: bool = False,
55
+ added_kv_proj_dim: Optional[int] = None,
56
+ norm_num_groups: Optional[int] = None,
57
+ ):
58
+ super().__init__()
59
+ inner_dim = dim_head * heads
60
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
61
+ self.upcast_attention = upcast_attention
62
+ self.upcast_softmax = upcast_softmax
63
+ self.upcast_efficient_attention = False
64
+
65
+ self.scale = dim_head**-0.5
66
+
67
+ self.heads = heads
68
+ # for slice_size > 0 the attention score computation
69
+ # is split across the batch axis to save memory
70
+ # You can set slice_size with `set_attention_slice`
71
+ self.sliceable_head_dim = heads
72
+ self._slice_size = None
73
+ self._use_memory_efficient_attention_xformers = False
74
+ self.added_kv_proj_dim = added_kv_proj_dim
75
+
76
+ if norm_num_groups is not None:
77
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
78
+ else:
79
+ self.group_norm = None
80
+
81
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
82
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
83
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
84
+
85
+ if self.added_kv_proj_dim is not None:
86
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
87
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
88
+
89
+ self.to_out = nn.ModuleList([])
90
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
91
+ self.to_out.append(nn.Dropout(dropout))
92
+
93
+ def reshape_heads_to_batch_dim(self, tensor):
94
+ batch_size, seq_len, dim = tensor.shape
95
+ head_size = self.heads
96
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
97
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size).contiguous()
98
+ return tensor
99
+
100
+ def reshape_heads_to_4d(self, tensor):
101
+ batch_size, seq_len, dim = tensor.shape
102
+ head_size = self.heads
103
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size).contiguous()
104
+ return tensor
105
+
106
+ def reshape_batch_dim_to_heads(self, tensor):
107
+ batch_size, seq_len, dim = tensor.shape
108
+ head_size = self.heads
109
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim).contiguous()
110
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size).contiguous()
111
+ return tensor
112
+
113
+ def reshape_4d_to_heads(self, tensor):
114
+ batch_size, seq_len, head_size, dim = tensor.shape
115
+ head_size = self.heads
116
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size).contiguous()
117
+ return tensor
118
+
119
+ def set_attention_slice(self, slice_size):
120
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
121
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
122
+
123
+ self._slice_size = slice_size
124
+
125
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
126
+ batch_size, sequence_length, _ = hidden_states.shape
127
+
128
+ encoder_hidden_states = encoder_hidden_states
129
+
130
+ if self.group_norm is not None:
131
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = self.to_q(hidden_states)
134
+ dim = query.shape[-1]
135
+ query = self.reshape_heads_to_batch_dim(query)
136
+
137
+ if self.added_kv_proj_dim is not None:
138
+ key = self.to_k(hidden_states)
139
+ value = self.to_v(hidden_states)
140
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
141
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
142
+
143
+ key = self.reshape_heads_to_batch_dim(key)
144
+ value = self.reshape_heads_to_batch_dim(value)
145
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
146
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
147
+
148
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
149
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
150
+ else:
151
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
152
+ key = self.to_k(encoder_hidden_states)
153
+ value = self.to_v(encoder_hidden_states)
154
+
155
+ key = self.reshape_heads_to_batch_dim(key)
156
+ value = self.reshape_heads_to_batch_dim(value)
157
+
158
+ if attention_mask is not None:
159
+ if attention_mask.shape[-1] != query.shape[1]:
160
+ target_length = query.shape[1]
161
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
162
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
163
+
164
+ # attention, what we cannot get enough of
165
+ if XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers:
166
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
167
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
168
+ hidden_states = hidden_states.to(query.dtype)
169
+ else:
170
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
171
+ hidden_states = self._attention(query, key, value, attention_mask)
172
+ else:
173
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
174
+
175
+ # linear proj
176
+ hidden_states = self.to_out[0](hidden_states)
177
+
178
+ # dropout
179
+ hidden_states = self.to_out[1](hidden_states)
180
+ return hidden_states
181
+
182
+ def _attention(self, query, key, value, attention_mask=None):
183
+ if self.upcast_attention:
184
+ query = query.float()
185
+ key = key.float()
186
+
187
+ attention_scores = torch.baddbmm(
188
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
189
+ query,
190
+ key.transpose(-1, -2),
191
+ beta=0,
192
+ alpha=self.scale,
193
+ )
194
+
195
+ if attention_mask is not None:
196
+ attention_scores = attention_scores + attention_mask
197
+
198
+ if self.upcast_softmax:
199
+ attention_scores = attention_scores.float()
200
+
201
+ attention_probs = attention_scores.softmax(dim=-1)
202
+
203
+ # cast back to the original dtype
204
+ attention_probs = attention_probs.to(value.dtype)
205
+
206
+ # compute attention output
207
+ hidden_states = torch.bmm(attention_probs, value)
208
+
209
+ # reshape hidden_states
210
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
211
+ return hidden_states
212
+
213
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
214
+ batch_size_attention = query.shape[0]
215
+ hidden_states = torch.zeros(
216
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
217
+ )
218
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
219
+ for i in range(hidden_states.shape[0] // slice_size):
220
+ start_idx = i * slice_size
221
+ end_idx = (i + 1) * slice_size
222
+
223
+ query_slice = query[start_idx:end_idx]
224
+ key_slice = key[start_idx:end_idx]
225
+
226
+ if self.upcast_attention:
227
+ query_slice = query_slice.float()
228
+ key_slice = key_slice.float()
229
+
230
+ attn_slice = torch.baddbmm(
231
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
232
+ query_slice,
233
+ key_slice.transpose(-1, -2),
234
+ beta=0,
235
+ alpha=self.scale,
236
+ )
237
+
238
+ if attention_mask is not None:
239
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
240
+
241
+ if self.upcast_softmax:
242
+ attn_slice = attn_slice.float()
243
+
244
+ attn_slice = attn_slice.softmax(dim=-1)
245
+
246
+ # cast back to the original dtype
247
+ attn_slice = attn_slice.to(value.dtype)
248
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
249
+
250
+ hidden_states[start_idx:end_idx] = attn_slice
251
+
252
+ # reshape hidden_states
253
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
254
+ return hidden_states
255
+
256
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
257
+ if self.upcast_efficient_attention:
258
+ org_dtype = query.dtype
259
+ query = query.float()
260
+ key = key.float()
261
+ value = value.float()
262
+ if attention_mask is not None:
263
+ attention_mask = attention_mask.float()
264
+ hidden_states = self._memory_efficient_attention_split(query, key, value, attention_mask)
265
+
266
+ if self.upcast_efficient_attention:
267
+ hidden_states = hidden_states.to(org_dtype)
268
+
269
+ hidden_states = self.reshape_4d_to_heads(hidden_states)
270
+ return hidden_states
271
+
272
+ # print("Error: no xformers")
273
+ # raise NotImplementedError
274
+
275
+ def _memory_efficient_attention_split(self, query, key, value, attention_mask):
276
+ batch_size = query.shape[0]
277
+ max_batch_size = 65535
278
+ num_batches = (batch_size + max_batch_size - 1) // max_batch_size
279
+ results = []
280
+ for i in range(num_batches):
281
+ start_idx = i * max_batch_size
282
+ end_idx = min((i + 1) * max_batch_size, batch_size)
283
+ query_batch = query[start_idx:end_idx]
284
+ key_batch = key[start_idx:end_idx]
285
+ value_batch = value[start_idx:end_idx]
286
+ if attention_mask is not None:
287
+ attention_mask_batch = attention_mask[start_idx:end_idx]
288
+ else:
289
+ attention_mask_batch = None
290
+ result = xformers.ops.memory_efficient_attention(query_batch, key_batch, value_batch, attn_bias=attention_mask_batch)
291
+ results.append(result)
292
+ full_result = torch.cat(results, dim=0)
293
+ return full_result
294
+
295
+
296
+ class FeedForward(nn.Module):
297
+ r"""
298
+ A feed-forward layer.
299
+
300
+ Parameters:
301
+ dim (`int`): The number of channels in the input.
302
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
303
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
304
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
305
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ dim: int,
311
+ dim_out: Optional[int] = None,
312
+ mult: int = 4,
313
+ dropout: float = 0.0,
314
+ activation_fn: str = "geglu",
315
+ ):
316
+ super().__init__()
317
+ inner_dim = int(dim * mult)
318
+ dim_out = dim_out if dim_out is not None else dim
319
+
320
+ if activation_fn == "gelu":
321
+ act_fn = GELU(dim, inner_dim)
322
+ elif activation_fn == "geglu":
323
+ act_fn = GEGLU(dim, inner_dim)
324
+ elif activation_fn == "geglu-approximate":
325
+ act_fn = ApproximateGELU(dim, inner_dim)
326
+
327
+ self.net = nn.ModuleList([])
328
+ # project in
329
+ self.net.append(act_fn)
330
+ # project dropout
331
+ self.net.append(nn.Dropout(dropout))
332
+ # project out
333
+ self.net.append(nn.Linear(inner_dim, dim_out))
334
+
335
+ def forward(self, hidden_states):
336
+ for module in self.net:
337
+ hidden_states = module(hidden_states)
338
+ return hidden_states
339
+
340
+
341
+ class GELU(nn.Module):
342
+ r"""
343
+ GELU activation function
344
+ """
345
+
346
+ def __init__(self, dim_in: int, dim_out: int):
347
+ super().__init__()
348
+ self.proj = nn.Linear(dim_in, dim_out)
349
+
350
+ def gelu(self, gate):
351
+ if gate.device.type != "mps":
352
+ return F.gelu(gate)
353
+ # mps: gelu is not implemented for float16
354
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
355
+
356
+ def forward(self, hidden_states):
357
+ hidden_states = self.proj(hidden_states)
358
+ hidden_states = self.gelu(hidden_states)
359
+ return hidden_states
360
+
361
+
362
+ # feedforward
363
+ class GEGLU(nn.Module):
364
+ r"""
365
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
366
+
367
+ Parameters:
368
+ dim_in (`int`): The number of channels in the input.
369
+ dim_out (`int`): The number of channels in the output.
370
+ """
371
+
372
+ def __init__(self, dim_in: int, dim_out: int):
373
+ super().__init__()
374
+ self.proj = nn.Linear(dim_in, dim_out * 2)
375
+
376
+ def gelu(self, gate):
377
+ if gate.device.type != "mps":
378
+ return F.gelu(gate)
379
+ # mps: gelu is not implemented for float16
380
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
381
+
382
+ def forward(self, hidden_states):
383
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
384
+ return hidden_states * self.gelu(gate)
385
+
386
+
387
+ class ApproximateGELU(nn.Module):
388
+ """
389
+ The approximate form of Gaussian Error Linear Unit (GELU)
390
+
391
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
392
+ """
393
+
394
+ def __init__(self, dim_in: int, dim_out: int):
395
+ super().__init__()
396
+ self.proj = nn.Linear(dim_in, dim_out)
397
+
398
+ def forward(self, x):
399
+ x = self.proj(x)
400
+ return x * torch.sigmoid(1.702 * x)
401
+
402
+
403
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
404
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
405
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
406
+ freqs = torch.outer(t, freqs)
407
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
408
+ return freqs_cis
409
+
410
+
411
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
412
+ ndim = x.ndim
413
+ assert 0 <= 1 < ndim
414
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
415
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
416
+ return freqs_cis.view(*shape)
417
+
418
+
419
+ def apply_rotary_emb(
420
+ xq: torch.Tensor,
421
+ xk: torch.Tensor,
422
+ freqs_cis: torch.Tensor,
423
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
424
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2).contiguous())
425
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2).contiguous())
426
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
427
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
428
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
429
+ return xq_out.type_as(xq), xk_out.type_as(xk)
code_depth/video_depth_anything/motion_module/motion_module.py ADDED
@@ -0,0 +1,297 @@
1
+ # This file is originally from AnimateDiff/animatediff/models/motion_module.py at main · guoyww/AnimateDiff
2
+ # SPDX-License-Identifier: Apache-2.0 license
3
+ #
4
+ # This file may have been modified by ByteDance Ltd. and/or its affiliates on [date of modification]
5
+ # Original file was released under [ Apache-2.0 license], with the full license text available at [https://github.com/guoyww/AnimateDiff?tab=Apache-2.0-1-ov-file#readme].
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from .attention import CrossAttention, FeedForward, apply_rotary_emb, precompute_freqs_cis
11
+
12
+ from einops import rearrange, repeat
13
+ import math
14
+
15
+ try:
16
+ import xformers
17
+ import xformers.ops
18
+
19
+ XFORMERS_AVAILABLE = True
20
+ except ImportError:
21
+ print("xFormers not available")
22
+ XFORMERS_AVAILABLE = False
23
+
24
+
25
+ def zero_module(module):
26
+ # Zero out the parameters of a module and return it.
27
+ for p in module.parameters():
28
+ p.detach().zero_()
29
+ return module
30
+
31
+
32
+ class TemporalModule(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ num_attention_heads = 8,
37
+ num_transformer_block = 2,
38
+ num_attention_blocks = 2,
39
+ norm_num_groups = 32,
40
+ temporal_max_len = 32,
41
+ zero_initialize = True,
42
+ pos_embedding_type = "ape",
43
+ ):
44
+ super().__init__()
45
+
46
+ self.temporal_transformer = TemporalTransformer3DModel(
47
+ in_channels=in_channels,
48
+ num_attention_heads=num_attention_heads,
49
+ attention_head_dim=in_channels // num_attention_heads,
50
+ num_layers=num_transformer_block,
51
+ num_attention_blocks=num_attention_blocks,
52
+ norm_num_groups=norm_num_groups,
53
+ temporal_max_len=temporal_max_len,
54
+ pos_embedding_type=pos_embedding_type,
55
+ )
56
+
57
+ if zero_initialize:
58
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
59
+
60
+ def forward(self, input_tensor, encoder_hidden_states, attention_mask=None):
61
+ hidden_states = input_tensor
62
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
63
+
64
+ output = hidden_states
65
+ return output
66
+
67
+
68
+ class TemporalTransformer3DModel(nn.Module):
69
+ def __init__(
70
+ self,
71
+ in_channels,
72
+ num_attention_heads,
73
+ attention_head_dim,
74
+ num_layers,
75
+ num_attention_blocks = 2,
76
+ norm_num_groups = 32,
77
+ temporal_max_len = 32,
78
+ pos_embedding_type = "ape",
79
+ ):
80
+ super().__init__()
81
+
82
+ inner_dim = num_attention_heads * attention_head_dim
83
+
84
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
85
+ self.proj_in = nn.Linear(in_channels, inner_dim)
86
+
87
+ self.transformer_blocks = nn.ModuleList(
88
+ [
89
+ TemporalTransformerBlock(
90
+ dim=inner_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ attention_head_dim=attention_head_dim,
93
+ num_attention_blocks=num_attention_blocks,
94
+ temporal_max_len=temporal_max_len,
95
+ pos_embedding_type=pos_embedding_type,
96
+ )
97
+ for d in range(num_layers)
98
+ ]
99
+ )
100
+ self.proj_out = nn.Linear(inner_dim, in_channels)
101
+
102
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
103
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
104
+ video_length = hidden_states.shape[2]
105
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
106
+
107
+ batch, channel, height, width = hidden_states.shape
108
+ residual = hidden_states
109
+
110
+ hidden_states = self.norm(hidden_states)
111
+ inner_dim = hidden_states.shape[1]
112
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim).contiguous()
113
+ hidden_states = self.proj_in(hidden_states)
114
+
115
+ # Transformer Blocks
116
+ for block in self.transformer_blocks:
117
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, attention_mask=attention_mask)
118
+
119
+ # output
120
+ hidden_states = self.proj_out(hidden_states)
121
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
122
+
123
+ output = hidden_states + residual
124
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
125
+
126
+ return output
127
+
128
+
129
+ class TemporalTransformerBlock(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim,
133
+ num_attention_heads,
134
+ attention_head_dim,
135
+ num_attention_blocks = 2,
136
+ temporal_max_len = 32,
137
+ pos_embedding_type = "ape",
138
+ ):
139
+ super().__init__()
140
+
141
+ self.attention_blocks = nn.ModuleList(
142
+ [
143
+ TemporalAttention(
144
+ query_dim=dim,
145
+ heads=num_attention_heads,
146
+ dim_head=attention_head_dim,
147
+ temporal_max_len=temporal_max_len,
148
+ pos_embedding_type=pos_embedding_type,
149
+ )
150
+ for i in range(num_attention_blocks)
151
+ ]
152
+ )
153
+ self.norms = nn.ModuleList(
154
+ [
155
+ nn.LayerNorm(dim)
156
+ for i in range(num_attention_blocks)
157
+ ]
158
+ )
159
+
160
+ self.ff = FeedForward(dim, dropout=0.0, activation_fn="geglu")
161
+ self.ff_norm = nn.LayerNorm(dim)
162
+
163
+
164
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
165
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
166
+ norm_hidden_states = norm(hidden_states)
167
+ hidden_states = attention_block(
168
+ norm_hidden_states,
169
+ encoder_hidden_states=encoder_hidden_states,
170
+ video_length=video_length,
171
+ attention_mask=attention_mask,
172
+ ) + hidden_states
173
+
174
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
175
+
176
+ output = hidden_states
177
+ return output
178
+
179
+
180
+ class PositionalEncoding(nn.Module):
181
+ def __init__(
182
+ self,
183
+ d_model,
184
+ dropout = 0.,
185
+ max_len = 32
186
+ ):
187
+ super().__init__()
188
+ self.dropout = nn.Dropout(p=dropout)
189
+ position = torch.arange(max_len).unsqueeze(1)
190
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
191
+ pe = torch.zeros(1, max_len, d_model)
192
+ pe[0, :, 0::2] = torch.sin(position * div_term)
193
+ pe[0, :, 1::2] = torch.cos(position * div_term)
194
+ self.register_buffer('pe', pe)
195
+
196
+ def forward(self, x):
197
+ x = x + self.pe[:, :x.size(1)].to(x.dtype)
198
+ return self.dropout(x)
199
+
200
+ class TemporalAttention(CrossAttention):
201
+ def __init__(
202
+ self,
203
+ temporal_max_len = 32,
204
+ pos_embedding_type = "ape",
205
+ *args, **kwargs
206
+ ):
207
+ super().__init__(*args, **kwargs)
208
+
209
+ self.pos_embedding_type = pos_embedding_type
210
+ self._use_memory_efficient_attention_xformers = True
211
+
212
+ self.pos_encoder = None
213
+ self.freqs_cis = None
214
+ if self.pos_embedding_type == "ape":
215
+ self.pos_encoder = PositionalEncoding(
216
+ kwargs["query_dim"],
217
+ dropout=0.,
218
+ max_len=temporal_max_len
219
+ )
220
+
221
+ elif self.pos_embedding_type == "rope":
222
+ self.freqs_cis = precompute_freqs_cis(
223
+ kwargs["query_dim"],
224
+ temporal_max_len
225
+ )
226
+
227
+ else:
228
+ raise NotImplementedError
229
+
230
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
231
+ d = hidden_states.shape[1]
232
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
233
+
234
+ if self.pos_encoder is not None:
235
+ hidden_states = self.pos_encoder(hidden_states)
236
+
237
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
238
+
239
+ if self.group_norm is not None:
240
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
241
+
242
+ query = self.to_q(hidden_states)
243
+ dim = query.shape[-1]
244
+
245
+ if self.added_kv_proj_dim is not None:
246
+ raise NotImplementedError
247
+
248
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
249
+ key = self.to_k(encoder_hidden_states)
250
+ value = self.to_v(encoder_hidden_states)
251
+
252
+ if self.freqs_cis is not None:
253
+ seq_len = query.shape[1]
254
+ freqs_cis = self.freqs_cis[:seq_len].to(query.device)
255
+ query, key = apply_rotary_emb(query, key, freqs_cis)
256
+
257
+ if attention_mask is not None:
258
+ if attention_mask.shape[-1] != query.shape[1]:
259
+ target_length = query.shape[1]
260
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
261
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
262
+
263
+
264
+ use_memory_efficient = XFORMERS_AVAILABLE and self._use_memory_efficient_attention_xformers
265
+ if use_memory_efficient and (dim // self.heads) % 8 != 0:
266
+ # print('Warning: the dim {} cannot be divided by 8. Fall into normal attention'.format(dim // self.heads))
267
+ use_memory_efficient = False
268
+
269
+ # attention, what we cannot get enough of
270
+ if use_memory_efficient:
271
+ query = self.reshape_heads_to_4d(query)
272
+ key = self.reshape_heads_to_4d(key)
273
+ value = self.reshape_heads_to_4d(value)
274
+
275
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
276
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
277
+ hidden_states = hidden_states.to(query.dtype)
278
+ else:
279
+ query = self.reshape_heads_to_batch_dim(query)
280
+ key = self.reshape_heads_to_batch_dim(key)
281
+ value = self.reshape_heads_to_batch_dim(value)
282
+
283
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
284
+ hidden_states = self._attention(query, key, value, attention_mask)
285
+ else:
286
+ raise NotImplementedError
287
+ # hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
288
+
289
+ # linear proj
290
+ hidden_states = self.to_out[0](hidden_states)
291
+
292
+ # dropout
293
+ hidden_states = self.to_out[1](hidden_states)
294
+
295
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
296
+
297
+ return hidden_states
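
The core trick in `TemporalAttention` is the `(b f) d c -> (b d) f c` rearrange: spatial positions are folded into the batch so self-attention runs only along the frame axis. Below is a minimal, self-contained sketch of that pattern, using `torch.nn.MultiheadAttention` as a stand-in for the `CrossAttention` subclass above so the shapes are easy to check; it is not the repo's class.

```python
import torch
from einops import rearrange

b, f, d, c = 2, 8, 16, 64            # batch, frames, spatial tokens, channels
x = torch.randn(b * f, d, c)          # layout the block receives: (b f) d c

attn = torch.nn.MultiheadAttention(embed_dim=c, num_heads=4, batch_first=True)

# fold spatial tokens into the batch; expose frames as the attention sequence
x_t = rearrange(x, "(b f) d c -> (b d) f c", f=f)
out, _ = attn(x_t, x_t, x_t)          # each spatial location attends across its 8 frames
out = rearrange(out, "(b d) f c -> (b f) d c", d=d)
print(out.shape)                       # torch.Size([16, 16, 64])
```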
code_depth/video_depth_anything/util/blocks.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(
21
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
22
+ )
23
+ scratch.layer2_rn = nn.Conv2d(
24
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
25
+ )
26
+ scratch.layer3_rn = nn.Conv2d(
27
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
28
+ )
29
+ if len(in_shape) >= 4:
30
+ scratch.layer4_rn = nn.Conv2d(
31
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
32
+ )
33
+
34
+ return scratch
35
+
36
+
37
+ class ResidualConvUnit(nn.Module):
38
+ """Residual convolution module."""
39
+
40
+ def __init__(self, features, activation, bn):
41
+ """Init.
42
+
43
+ Args:
44
+ features (int): number of features
45
+ """
46
+ super().__init__()
47
+
48
+ self.bn = bn
49
+
50
+ self.groups = 1
51
+
52
+ self.conv1 = nn.Conv2d(
53
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
54
+ )
55
+
56
+ self.conv2 = nn.Conv2d(
57
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
58
+ )
59
+
60
+ if self.bn is True:
61
+ self.bn1 = nn.BatchNorm2d(features)
62
+ self.bn2 = nn.BatchNorm2d(features)
63
+
64
+ self.activation = activation
65
+
66
+ self.skip_add = nn.quantized.FloatFunctional()
67
+
68
+ def forward(self, x):
69
+ """Forward pass.
70
+
71
+ Args:
72
+ x (tensor): input
73
+
74
+ Returns:
75
+ tensor: output
76
+ """
77
+
78
+ out = self.activation(x)
79
+ out = self.conv1(out)
80
+ if self.bn is True:
81
+ out = self.bn1(out)
82
+
83
+ out = self.activation(out)
84
+ out = self.conv2(out)
85
+ if self.bn is True:
86
+ out = self.bn2(out)
87
+
88
+ if self.groups > 1:
89
+ out = self.conv_merge(out)
90
+
91
+ return self.skip_add.add(out, x)
92
+
93
+
94
+ class FeatureFusionBlock(nn.Module):
95
+ """Feature fusion block."""
96
+
97
+ def __init__(
98
+ self,
99
+ features,
100
+ activation,
101
+ deconv=False,
102
+ bn=False,
103
+ expand=False,
104
+ align_corners=True,
105
+ size=None,
106
+ ):
107
+ """Init.
108
+
109
+ Args:
110
+ features (int): number of features
111
+ """
112
+ super().__init__()
113
+
114
+ self.deconv = deconv
115
+ self.align_corners = align_corners
116
+
117
+ self.groups = 1
118
+
119
+ self.expand = expand
120
+ out_features = features
121
+ if self.expand is True:
122
+ out_features = features // 2
123
+
124
+ self.out_conv = nn.Conv2d(
125
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1
126
+ )
127
+
128
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
129
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
130
+
131
+ self.skip_add = nn.quantized.FloatFunctional()
132
+
133
+ self.size = size
134
+
135
+ def forward(self, *xs, size=None):
136
+ """Forward pass.
137
+
138
+ Returns:
139
+ tensor: output
140
+ """
141
+ output = xs[0]
142
+
143
+ if len(xs) == 2:
144
+ res = self.resConfUnit1(xs[1])
145
+ output = self.skip_add.add(output, res)
146
+
147
+ output = self.resConfUnit2(output)
148
+
149
+ if (size is None) and (self.size is None):
150
+ modifier = {"scale_factor": 2}
151
+ elif size is None:
152
+ modifier = {"size": self.size}
153
+ else:
154
+ modifier = {"size": size}
155
+
156
+ output = nn.functional.interpolate(
157
+ output.contiguous(), **modifier, mode="bilinear", align_corners=self.align_corners
158
+ )
159
+
160
+ output = self.out_conv(output)
161
+
162
+ return output
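
For orientation, a small sketch of how these blocks compose in a DPT-style decoder: `_make_scratch` projects multi-scale backbone features to a common width, and `FeatureFusionBlock` refines and 2x-upsamples while optionally adding the next skip connection (in the actual DPT head each level gets its own fusion block). The import path and channel sizes are illustrative and assume the snippet is run from `code_depth/`.

```python
import torch
import torch.nn as nn
from video_depth_anything.util.blocks import FeatureFusionBlock, _make_scratch

scratch = _make_scratch([48, 96, 192, 384], out_shape=32)   # 4 backbone stages -> 32 channels each
fuse = FeatureFusionBlock(32, nn.ReLU(False), bn=False)

c4 = scratch.layer4_rn(torch.randn(1, 384, 8, 8))    # coarsest level  -> (1, 32, 8, 8)
c3 = scratch.layer3_rn(torch.randn(1, 192, 16, 16))  # next skip level -> (1, 32, 16, 16)

top = fuse(c4)          # refine + 2x upsample          -> (1, 32, 16, 16)
merged = fuse(top, c3)  # add skip, refine, 2x upsample -> (1, 32, 32, 32)
print(top.shape, merged.shape)
```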
code_depth/video_depth_anything/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height is constrained to be multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
39
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normlize image by given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
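
These transforms operate on a `dict` sample and are chained with `torchvision.transforms.Compose`, exactly as `video_depth.py` below does. A quick sketch (the import path assumes running from `code_depth/`; the random image is just a stand-in for a video frame):

```python
import cv2
import numpy as np
from torchvision.transforms import Compose
from video_depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method="lower_bound",
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

rgb = np.random.rand(480, 640, 3).astype(np.float32)   # HWC image in [0, 1]
out = transform({"image": rgb})["image"]
print(out.shape, out.dtype)                             # (3, 518, 686) float32 -- sides are multiples of 14
```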
code_depth/video_depth_anything/video_depth.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (2025) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import torch.nn as nn
17
+ from torchvision.transforms import Compose
18
+ import cv2
19
+ from tqdm import tqdm
20
+ import numpy as np
21
+ import gc
22
+
23
+ from .dinov2 import DINOv2
24
+ from .dpt_temporal import DPTHeadTemporal
25
+ from .util.transform import Resize, NormalizeImage, PrepareForNet
26
+
27
+ from utils.util import compute_scale_and_shift, get_interpolate_frames
28
+
29
+ # infer settings, do not change
30
+ INFER_LEN = 32
31
+ OVERLAP = 10
32
+ KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
33
+ INTERP_LEN = 8
34
+
35
+ class VideoDepthAnything(nn.Module):
36
+ def __init__(
37
+ self,
38
+ encoder='vitl',
39
+ features=256,
40
+ out_channels=[256, 512, 1024, 1024],
41
+ use_bn=False,
42
+ use_clstoken=False,
43
+ num_frames=32,
44
+ pe='ape'
45
+ ):
46
+ super(VideoDepthAnything, self).__init__()
47
+
48
+ self.intermediate_layer_idx = {
49
+ 'vits': [2, 5, 8, 11],
50
+ 'vitl': [4, 11, 17, 23]
51
+ }
52
+
53
+ self.encoder = encoder
54
+ self.pretrained = DINOv2(model_name=encoder)
55
+
56
+ self.head = DPTHeadTemporal(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken, num_frames=num_frames, pe=pe)
57
+
58
+ def forward(self, x):
59
+ B, T, C, H, W = x.shape
60
+ patch_h, patch_w = H // 14, W // 14
61
+ features = self.pretrained.get_intermediate_layers(x.flatten(0,1), self.intermediate_layer_idx[self.encoder], return_class_token=True)
62
+ depth = self.head(features, patch_h, patch_w, T)
63
+ depth = F.interpolate(depth, size=(H, W), mode="bilinear", align_corners=True)
64
+ depth = F.relu(depth)
65
+ return depth.squeeze(1).unflatten(0, (B, T)) # return shape [B, T, H, W]
66
+
67
+ def infer_video_depth(self, frames, target_fps, input_size=518, device='cuda', fp32=False):
68
+ frame_height, frame_width = frames[0].shape[:2]
69
+ ratio = max(frame_height, frame_width) / min(frame_height, frame_width)
70
+ if ratio > 1.78: # we recommend to process video with ratio smaller than 16:9 due to memory limitation
71
+ input_size = int(input_size * 1.777 / ratio)
72
+ input_size = round(input_size / 14) * 14
73
+
74
+ transform = Compose([
75
+ Resize(
76
+ width=input_size,
77
+ height=input_size,
78
+ resize_target=False,
79
+ keep_aspect_ratio=True,
80
+ ensure_multiple_of=14,
81
+ resize_method='lower_bound',
82
+ image_interpolation_method=cv2.INTER_CUBIC,
83
+ ),
84
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
85
+ PrepareForNet(),
86
+ ])
87
+
88
+ frame_list = [frames[i] for i in range(frames.shape[0])]
89
+ frame_step = INFER_LEN - OVERLAP
90
+ org_video_len = len(frame_list)
91
+ append_frame_len = (frame_step - (org_video_len % frame_step)) % frame_step + (INFER_LEN - frame_step)
92
+ frame_list = frame_list + [frame_list[-1].copy()] * append_frame_len
93
+
94
+ depth_list = []
95
+ pre_input = None
96
+ for frame_id in tqdm(range(0, org_video_len, frame_step)):
97
+ cur_list = []
98
+ for i in range(INFER_LEN):
99
+ cur_list.append(torch.from_numpy(transform({'image': frame_list[frame_id+i].astype(np.float32) / 255.0})['image']).unsqueeze(0).unsqueeze(0))
100
+ cur_input = torch.cat(cur_list, dim=1).to(device)
101
+ if pre_input is not None:
102
+ cur_input[:, :OVERLAP, ...] = pre_input[:, KEYFRAMES, ...]
103
+
104
+ with torch.no_grad():
105
+ with torch.autocast(device_type=device, enabled=(not fp32)):
106
+ depth = self.forward(cur_input) # depth shape: [1, T, H, W]
107
+
108
+ depth = depth.to(cur_input.dtype)
109
+ depth = F.interpolate(depth.flatten(0,1).unsqueeze(1), size=(frame_height, frame_width), mode='bilinear', align_corners=True)
110
+ depth_list += [depth[i][0].cpu().numpy() for i in range(depth.shape[0])]
111
+
112
+ pre_input = cur_input
113
+
114
+ del frame_list
115
+ gc.collect()
116
+
117
+ depth_list_aligned = []
118
+ ref_align = []
119
+ align_len = OVERLAP - INTERP_LEN
120
+ kf_align_list = KEYFRAMES[:align_len]
121
+
122
+ for frame_id in range(0, len(depth_list), INFER_LEN):
123
+ if len(depth_list_aligned) == 0:
124
+ depth_list_aligned += depth_list[:INFER_LEN]
125
+ for kf_id in kf_align_list:
126
+ ref_align.append(depth_list[frame_id+kf_id])
127
+ else:
128
+ curr_align = []
129
+ for i in range(len(kf_align_list)):
130
+ curr_align.append(depth_list[frame_id+i])
131
+ scale, shift = compute_scale_and_shift(np.concatenate(curr_align),
132
+ np.concatenate(ref_align),
133
+ np.concatenate(np.ones_like(ref_align)==1))
134
+
135
+ pre_depth_list = depth_list_aligned[-INTERP_LEN:]
136
+ post_depth_list = depth_list[frame_id+align_len:frame_id+OVERLAP]
137
+ for i in range(len(post_depth_list)):
138
+ post_depth_list[i] = post_depth_list[i] * scale + shift
139
+ post_depth_list[i][post_depth_list[i]<0] = 0
140
+ depth_list_aligned[-INTERP_LEN:] = get_interpolate_frames(pre_depth_list, post_depth_list)
141
+
142
+ for i in range(OVERLAP, INFER_LEN):
143
+ new_depth = depth_list[frame_id+i] * scale + shift
144
+ new_depth[new_depth<0] = 0
145
+ depth_list_aligned.append(new_depth)
146
+
147
+ ref_align = ref_align[:1]
148
+ for kf_id in kf_align_list[1:]:
149
+ new_depth = depth_list[frame_id+kf_id] * scale + shift
150
+ new_depth[new_depth<0] = 0
151
+ ref_align.append(new_depth)
152
+
153
+ depth_list = depth_list_aligned
154
+
155
+ return np.stack(depth_list[:org_video_len], axis=0), target_fps
156
+
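
The windowing constants above are easiest to see in a small standalone sketch: the video is padded so that overlapping 32-frame windows advancing by `INFER_LEN - OVERLAP` frames cover it exactly, and the first `OVERLAP` frames of every window after the first are replaced by the previous window's `KEYFRAMES` predictions before scale/shift alignment.

```python
INFER_LEN, OVERLAP = 32, 10          # same constants as in video_depth.py
KEYFRAMES = [0, 12, 24, 25, 26, 27, 28, 29, 30, 31]

def plan_windows(num_frames):
    step = INFER_LEN - OVERLAP                                      # 22
    pad = (step - (num_frames % step)) % step + (INFER_LEN - step)  # same padding rule as infer_video_depth
    starts = list(range(0, num_frames, step))
    return pad, starts

pad, starts = plan_windows(100)
print(pad, starts)   # 20 [0, 22, 44, 66, 88] -> padded length 120; the last window spans frames 88-119
```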
code_edit/.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
code_edit/Flux_fill_d2i.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import FluxFillPipeline_token12_depth as FluxFillPipeline
3
+ from diffusers.utils import load_image
4
+ import os, glob
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image
8
+
9
+ image_path = ["example_data/I-210618_I01001_W01_I-210618_I01001_W01_F0153_img.jpg"]
10
+
11
+ pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
12
+ pipe.load_lora_weights("stage2/checkpoint-20000")
13
+ for image_ep in image_path:
14
+ image = Image.open(image_ep)
15
+ mask = Image.new("L", image.size, 0) # place_hold
16
+ depth_path = image_ep.replace("_img.jpg", "_depth_img.png")
17
+ depth_image = Image.open(depth_path)
18
+ depth = Image.open(depth_path.replace("_img", "_img_fill_in"))
19
+ image_name = os.path.basename(image_ep)
20
+
21
+ orig_w, orig_h = image.size
22
+ w, h = image.size
23
+ MAX_SIZE = 1024
24
+ if max(w, h) > MAX_SIZE:
25
+ factor = MAX_SIZE / max(w, h)
26
+ w = int(factor * w)
27
+ h = int(factor * h)
28
+ width, height = map(lambda x: x - x % 64, (w, h))
29
+ # # Resize to 1024 × 1024
30
+ target_size = (width, height)
31
+ # target_size = (1024, 1024)
32
+ # image_resized = image.resize(target_size, Image.BICUBIC)
33
+ # mask_resized = mask.resize(target_size, Image.NEAREST)
34
+ # depth_resized = depth.resize(target_size, Image.BICUBIC)
35
+ # depth_image_resized = depth_image.resize(target_size, Image.BICUBIC)
36
+
37
+ image = pipe(
38
+ prompt="A beautiful scene",
39
+ image=image,
40
+ mask_image=mask,
41
+ width=target_size[0],
42
+ height=target_size[1],
43
+ guidance_scale=30,
44
+ num_inference_steps=50,
45
+ max_sequence_length=512,
46
+ generator=torch.Generator("cpu").manual_seed(0),
47
+ depth=depth,
48
+ depth_image=depth_image,
49
+ ).images[0]
50
+ image_final = image.resize((orig_w * 3, orig_h), Image.BICUBIC)
51
+ output_dir = "./test_images/"
52
+ os.makedirs(output_dir, exist_ok=True)
53
+ image_final.save(os.path.join(output_dir, image_name))
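
This script and `Flux_fill_infer_depth.py` below compute the pipeline's `width`/`height` the same way: cap the longer side at 1024 px, then round each side down to a multiple of 64. Pulled out here purely for clarity:

```python
def fill_target_size(w, h, max_size=1024, multiple=64):
    """Cap the longer side at max_size, then floor both sides to a multiple of `multiple`."""
    if max(w, h) > max_size:
        factor = max_size / max(w, h)
        w, h = int(factor * w), int(factor * h)
    return w - w % multiple, h - h % multiple

print(fill_target_size(1920, 1080))  # (1024, 576)
```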
code_edit/Flux_fill_infer_depth.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ from diffusers.pipelines.flux.pipeline_flux_fill_unmasked_image_condition_version import FluxFillPipeline_token12_depth_only as FluxFillPipeline
3
+ from diffusers.utils import load_image
4
+ import os, glob
5
+ import numpy as np
6
+ import cv2
7
+ from PIL import Image, ImageOps
8
+
9
+ image_path = ["example_data/I-210618_I01001_W01_I-210618_I01001_W01_F0153_img.jpg"]
10
+ pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
11
+ pipe.load_lora_weights("stage1/checkpoint-4800")
12
+ for image_ep in image_path:
13
+ mask_path = image_ep.replace("_img.jpg","_mask.png")
14
+ image = Image.open(image_ep) # placeholder
15
+ depth = Image.open(image_ep.replace("_img.jpg",
16
+ "_depth_img.png"))
17
+ image_name = os.path.basename(image_ep)
18
+ mask = Image.open(mask_path).convert("L")
19
+ mask = ImageOps.invert(mask) # invert the RORD mask
20
+
21
+ # mask_np = np.array(mask)
22
+
23
+ # # mask dilation
24
+ # dilation_px = 32
25
+ # kernel = np.ones((3, 3), np.uint8)
26
+ # iterations = dilation_px // 2
27
+ # dilated_mask = cv2.dilate(mask_np, kernel, iterations=iterations)
28
+ # mask = Image.fromarray(dilated_mask)
29
+
30
+ orig_w, orig_h = image.size
31
+
32
+ # Resize to 1024 × 1024
33
+ # target_size = (1024, 1024)
34
+ # image_resized = image.resize(target_size, Image.BICUBIC)
35
+ # mask_resized = mask.resize(target_size, Image.NEAREST)
36
+ # depth_resized = depth.resize(target_size, Image.BICUBIC)
37
+
38
+ w, h = image.size
39
+ MAX_SIZE = 1024
40
+ if max(w, h) > MAX_SIZE:
41
+ factor = MAX_SIZE / max(w, h)
42
+ w = int(factor * w)
43
+ h = int(factor * h)
44
+ width, height = map(lambda x: x - x % 64, (w, h))
45
+ image_out = pipe(
46
+ prompt="A beautiful scene",
47
+ image=image,
48
+ mask_image=mask,
49
+ width=width,
50
+ height=height,
51
+ guidance_scale=30,
52
+ num_inference_steps=50,
53
+ max_sequence_length=512,
54
+ generator=torch.Generator("cpu").manual_seed(0),
55
+ depth=depth
56
+ ).images[0]
57
+
58
+
59
+ image_final = image_out.resize((orig_w, orig_h), Image.BICUBIC)
60
+
61
+ output_dir = "./depth_fillin_results"
62
+ os.makedirs(output_dir, exist_ok=True)
63
+ image_final.save(os.path.join(output_dir, image_name))
64
+
code_edit/README.md ADDED
@@ -0,0 +1,93 @@
1
+ The official implementation of the **NeurIPS 2025** paper:
2
+
3
+ <div align="center">
4
+ <h1>
5
+ <b>
6
+ GeoRemover: Removing Objects and Their Causal Visual Artifacts, NeurIPS, 2025 (Spotlight)
7
+ </b>
8
+ </h1>
9
+ </div>
10
+
11
+ <p align="center"><img src="docs/teaser.png" width="800"/></p>
12
+
13
+ > [**GeoRemover: Removing Objects and Their Causal Visual Artifacts**](https://arxiv.org/abs/2509.18538)
14
+ >
15
+ > Zixin Zhu, Haoxiang Li, Xuelu Feng, He Wu, Chunming Qiao, Junsong Yuan
16
+
17
+ > **Abstract:** *Towards intelligent image editing, object removal should eliminate both the target object and its causal visual artifacts, such as shadows and reflections. However, existing image appearance-based methods either follow strictly mask-aligned training and fail to remove these causal effects, which are not explicitly masked, or adopt loosely mask-aligned strategies that lack controllability and may unintentionally over-erase other objects. We identify that these limitations stem from ignoring the causal relationship between an object’s geometry presence and its visual effects. To address this limitation, we propose a geometry-aware two-stage framework that decouples object removal into (1) geometry removal and (2) appearance rendering. In the first stage, we remove the object directly from the geometry (e.g., depth) using strictly mask-aligned supervision, enabling structure-aware editing with strong geometric constraints. In the second stage, we render a photorealistic RGB image conditioned on the updated geometry, where causal visual effects are considered implicitly as a result of the modified 3D geometry. To guide learning in the geometry removal stage, we introduce a preference-driven objective based on positive and negative sample pairs, encouraging the model to remove objects as well as their causal visual artifacts while avoiding new structural insertions. Extensive experiments demonstrate that our method achieves state-of-the-art performance in removing both objects and their associated artifacts on two popular benchmarks.*
18
+
19
+ ### Installing the dependencies
20
+
21
+ Before running the scripts, make sure to install the library's training dependencies:
22
+
23
+ **Important**
24
+
25
+ ```bash
26
+ bash env.sh
27
+ ```
28
+
29
+ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
30
+
31
+ ```bash
32
+ accelerate config
33
+ ```
34
+
35
+ Or, for a default Accelerate configuration without answering questions about your environment:
36
+
37
+ ```bash
38
+ accelerate config default
39
+ ```
40
+
41
+ ### Data prepare
42
+ Download the images from [RORD](https://github.com/Forty-lock/RORD) and generate depth maps with [Video-Depth-Anything v2](https://github.com/DepthAnything/Video-Depth-Anything). (The Video-Depth-Anything v2 code lives on the `depth` branch of this repository; use this [script](https://github.com/buxiangzhiren/GeoRemover/blob/depth/run_images_rord.py).) A minimal per-image sketch is shown below.
43
+
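
`run_images_rord.py` on the `depth` branch handles the full dataset; purely for illustration, here is a minimal per-image sketch using the `VideoDepthAnything` class from `code_depth/`. The checkpoint path and input filename are placeholders (see `code_depth/get_weights.sh`), and the snippet assumes it is run from `code_depth/`.

```python
import cv2
import numpy as np
import torch
from video_depth_anything.video_depth import VideoDepthAnything

model = VideoDepthAnything(encoder="vitl")
state = torch.load("checkpoints/video_depth_anything_vitl.pth", map_location="cpu")  # placeholder checkpoint path
model.load_state_dict(state)
model = model.to("cuda").eval()

frame = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)                   # placeholder input image
depths, _ = model.infer_video_depth(np.stack([frame]), target_fps=1, device="cuda")  # treat it as a 1-frame video
depth_vis = (depths[0] / max(depths[0].max(), 1e-6) * 255).astype(np.uint8)  # rough 8-bit visualization only;
cv2.imwrite("example_depth.png", depth_vis)                                  # match run_images_rord.py's encoding for real data
```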
44
+ ### Training
45
+ You should build your own *train_images_and_rord_masks.csv* first; the CSV included in this repo is only a small example, not the full RORD dataset.
46
+
47
+ For stage 1 (geometry removal):
48
+ ```bash
49
+ bash train_stage1.sh
50
+ ```
51
+ For stage 2 (appearance rendering):
52
+ ```bash
53
+ bash train_stage2.sh
54
+ ```
55
+ ### Inference
56
+ First, use [run_single_image.py](https://github.com/buxiangzhiren/GeoRemover/blob/depth/run_single_image.py) (on the `depth` branch) to get the depth map of an image.
57
+
58
+ For stage 1 (geometry removal):
59
+ ```bash
60
+ python Flux_fill_infer_depth.py
61
+ ```
62
+ For stage 2 (appearance rendering):
63
+ ```bash
64
+ python Flux_fill_d2i.py
65
+ ```
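
Note on chaining the two stages (based on the paths hard-coded in the two scripts): `Flux_fill_infer_depth.py` writes the edited depth to `./depth_fillin_results/<name>_img.jpg`, while `Flux_fill_d2i.py` expects it as `<name>_depth_img_fill_in.png` next to the original inputs, so you likely need to convert/rename the stage-1 output first, e.g.:

```python
from PIL import Image

name = "I-210618_I01001_W01_I-210618_I01001_W01_F0153"   # the example_data sample used in both scripts
Image.open(f"depth_fillin_results/{name}_img.jpg").save(
    f"example_data/{name}_depth_img_fill_in.png")
```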
66
+ ### Checkpoints
67
+ Hugging Face:
68
+ [stage 1 (geometry removal) and stage 2 (appearance rendering)](https://huggingface.co/buxiangzhiren/GeoRemover)
69
+
70
+
71
+ Google Drive:
72
+ [stage 1 (geometry removal)](https://drive.google.com/file/d/1y6vnxqnFTiO6sxoKDBkvFbAeniHFka89/view?usp=sharing)
73
+ and [stage 2 (appearance rendering)](https://drive.google.com/file/d/1U8rp1hqOswQB-0T0fh2aDQu-o1GLfd6E/view?usp=sharing)
74
+
75
+
76
+ ### Acknowledgement
77
+
78
+ This repo builds on [RORD](https://github.com/Forty-lock/RORD), [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) and [Video-Depth-Anything v2](https://github.com/DepthAnything/Video-Depth-Anything). Thanks for their wonderful work.
79
+
80
+
81
+ ### Citation
82
+
83
+ ```
84
+ @misc{zhu2025georemoverremovingobjectscausal,
85
+ title={GeoRemover: Removing Objects and Their Causal Visual Artifacts},
86
+ author={Zixin Zhu and Haoxiang Li and Xuelu Feng and He Wu and Chunming Qiao and Junsong Yuan},
87
+ year={2025},
88
+ eprint={2509.18538},
89
+ archivePrefix={arXiv},
90
+ primaryClass={cs.CV},
91
+ url={https://arxiv.org/abs/2509.18538},
92
+ }
93
+ ```