toshas commited on
Commit
2c83504
1 Parent(s): 0b8151e

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .idea
2
+ .DS_Store
README.md CHANGED
@@ -1,13 +1,24 @@
1
  ---
2
- title: Marigold
3
- emoji: 🦀
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.8.0
8
  app_file: app.py
9
- pinned: false
10
  license: cc-by-sa-4.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Marigold Depth Estimation
3
+ emoji: 🏵️
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.44.4
8
  app_file: app.py
9
+ pinned: true
10
  license: cc-by-sa-4.0
11
  ---
12
 
13
+ This is a demo of the monocular depth estimation pipeline, described in the paper titled ["Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation"](https://arxiv.org/abs/2312.02145)
14
+
15
+ ```
16
+ @misc{ke2023repurposing,
17
+ title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
18
+ author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
19
+ year={2023},
20
+ eprint={2312.02145},
21
+ archivePrefix={arXiv},
22
+ primaryClass={cs.CV}
23
+ }
24
+ ```
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ import gradio as gr
5
+
6
+
7
+ desc = """
8
+ <p align="center">
9
+ <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
10
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
11
+ </a>
12
+ <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
13
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
14
+ </a>
15
+ <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
16
+ <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
17
+ </a>
18
+ <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
19
+ <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
20
+ </a>
21
+ </p>
22
+ <p align="justify">
23
+ Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or expore examples listed in the bottom.
24
+ </p>
25
+ """
26
+
27
+
28
+ def init_persistence(purge=False):
29
+ if not os.path.exists('/data'):
30
+ return
31
+ os.environ['ckpt_dir'] = "/data/Marigold_ckpt"
32
+ os.environ['TRANSFORMERS_CACHE'] = "/data/hfcache"
33
+ os.environ['HF_DATASETS_CACHE'] = "/data/hfcache"
34
+ os.environ['HF_HOME'] = "/data/hfcache"
35
+ if purge:
36
+ os.system("rm -rf /data/Marigold_ckpt/*")
37
+
38
+
39
+ def download_code_weights():
40
+ os.system('git clone https://github.com/prs-eth/Marigold.git')
41
+ os.system('cd Marigold && bash script/download_weights.sh')
42
+ os.system('echo /data && ls -la /data')
43
+ os.system('echo /data/Marigold_ckpt && ls -la /data/Marigold_ckpt')
44
+ os.system('echo /data/Marigold_ckpt/Marigold_v1_merged && ls -la /data/Marigold_ckpt/Marigold_v1_merged')
45
+
46
+
47
+ def find_first_png(directory):
48
+ for file in os.listdir(directory):
49
+ if file.lower().endswith(".png"):
50
+ return os.path.join(directory, file)
51
+ return None
52
+
53
+
54
+ def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
55
+ if path_out_png is not None and path_out_obj is not None and path_out_2_png is not None:
56
+ return path_out_png, path_out_obj, path_out_2_png
57
+
58
+ path_input_dir = path_input + ".input"
59
+ path_output_dir = path_input + ".output"
60
+ os.makedirs(path_input_dir, exist_ok=True)
61
+ os.makedirs(path_output_dir, exist_ok=True)
62
+ shutil.copy(path_input, path_input_dir)
63
+
64
+ persistence_args = ""
65
+ if os.path.exists('/data'):
66
+ persistence_args = "--checkpoint /data/Marigold_ckpt/Marigold_v1_merged"
67
+
68
+ os.system(
69
+ f"cd Marigold && python3 run.py "
70
+ f"{persistence_args} "
71
+ f"--input_rgb_dir \"{path_input_dir}\" "
72
+ f"--output_dir \"{path_output_dir}\" "
73
+ f"--n_infer 5 "
74
+ f"--denoise_steps 10 "
75
+ )
76
+
77
+ # depth_colored, depth_bw, depth_npy
78
+ path_out_colored = find_first_png(path_output_dir + "/depth_colored")
79
+ assert path_out_colored is not None, "Processing failed"
80
+ path_out_bw = find_first_png(path_output_dir + "/depth_bw")
81
+ assert path_out_bw is not None, "Processing failed"
82
+
83
+ return path_out_colored, path_out_bw
84
+
85
+
86
+ iface = gr.Interface(
87
+ title="Marigold Depth Estimation",
88
+ description=desc,
89
+ thumbnail="marigold_logo_square.jpg",
90
+ fn=marigold_process,
91
+ inputs=[
92
+ gr.Image(
93
+ label="Input Image",
94
+ type="filepath",
95
+ ),
96
+ gr.File(
97
+ label="Predicted depth (red-near, blue-far)",
98
+ visible=False,
99
+ ),
100
+ gr.File(
101
+ label="Predicted depth (16-bit PNG)",
102
+ visible=False,
103
+ ),
104
+ ],
105
+ outputs=[
106
+ gr.Image(
107
+ label="Predicted depth (red-near, blue-far)",
108
+ type="pil",
109
+ ),
110
+ gr.Image(
111
+ label="Predicted depth (16-bit PNG)",
112
+ type="pil",
113
+ elem_classes="imgdownload",
114
+ ),
115
+ ],
116
+ allow_flagging="never",
117
+ # examples=[
118
+ # [
119
+ # os.path.join(os.path.dirname(__file__), "files/test.png"),
120
+ # os.path.join(os.path.dirname(__file__), "files/test.png.out.png"),
121
+ # os.path.join(os.path.dirname(__file__), "files/test.png.out.2.png"),
122
+ # ],
123
+ # ],
124
+ css="""
125
+ .viewport {
126
+ aspect-ratio: 4/3;
127
+ }
128
+ .imgdownload {
129
+ height: 32px;
130
+ }
131
+ """,
132
+ cache_examples=True,
133
+ )
134
+
135
+
136
+ if __name__ == "__main__":
137
+ init_persistence()
138
+ download_code_weights()
139
+ iface.queue().launch(server_name="0.0.0.0", server_port=7860)
files/Bee_Collecting_Pollen_2004-08-14.jpg ADDED
marigold_logo_square.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.44.4
2
+ gradio_client==0.5.1
3
+ trimesh==3.23.5
4
+
5
+ accelerate
6
+ diffusers==0.20.1
7
+ h5py
8
+ matplotlib
9
+ numpy==1.26.1
10
+ omegaconf
11
+ opencv-python
12
+ pandas
13
+ scipy==1.11.3
14
+ tabulate
15
+ tensorboard
16
+ torch==2.0.1
17
+ torchaudio
18
+ torchvision
19
+ torchshow
20
+ tqdm
21
+ transformers
22
+ triton
23
+ wandb==0.14.0
24
+ xformers