toshas committed on
Commit 536e673
Parent: 62560ae

update the pipeline to the latest one from diffusers


update Gradio to 4.21.0
add CVPR acceptance note
support Gradio Spaces ZeroGPU
fix the 3D printable artefact to have horizontal orientation
change the minimum number of denoising steps to 10
add a note pointing to Marigold-LCM

Files changed (9)
  1. .gitattributes +3 -2
  2. README.md +8 -10
  3. app.py +100 -46
  4. extrude.py +33 -1
  5. files/bee.jpg +0 -0
  6. files/cat.jpg +0 -0
  7. files/swings.jpg +0 -0
  8. marigold_depth_estimation.py +632 -0
  9. requirements.txt +4 -4
.gitattributes CHANGED
@@ -33,5 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-files/einstein_depth_fp32.npy filter=lfs diff=lfs merge=lfs -text
-files/einstein_depth_16bit.png filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -4,23 +4,21 @@ emoji: 🏵️
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version: 4.11.0
+sdk_version: 4.21.0
 app_file: app.py
 pinned: true
 license: cc-by-sa-4.0
 models:
-- Bingxin/Marigold
+- prs-eth/marigold-v1-0
 ---
 
-This is a demo of the monocular depth estimation pipeline, described in the paper titled ["Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation"](https://arxiv.org/abs/2312.02145)
+This is a demo of the monocular depth estimation pipeline, described in the CVPR 2024 paper titled ["Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation"](https://arxiv.org/abs/2312.02145)
 
 ```
-@misc{ke2023repurposing,
-  title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
-  author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
-  year={2023},
-  eprint={2312.02145},
-  archivePrefix={arXiv},
-  primaryClass={cs.CV}
+@InProceedings{ke2023repurposing,
+  title={Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation},
+  author={Bingxin Ke and Anton Obukhov and Shengyu Huang and Nando Metzger and Rodrigo Caye Daudt and Konrad Schindler},
+  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2024}
 }
 ```
app.py CHANGED
@@ -1,17 +1,37 @@
+# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+
+
 import functools
 import os
-import shutil
-import sys
 
-import git
+import spaces
 import gradio as gr
 import numpy as np
 import torch as torch
 from PIL import Image
 
 from gradio_imageslider import ImageSlider
+from huggingface_hub import login
 
 from extrude import extrude_depth_3d
+from marigold_depth_estimation import MarigoldPipeline
 
 
 def process(
@@ -82,12 +102,21 @@ def process_3d(
     frame_far,
 ):
     if input_image is None or len(files) < 1:
-        raise gr.Error("Please upload an image (or use examples) and compute depth first")
+        raise gr.Error(
+            "Please upload an image (or use examples) and compute depth first"
+        )
 
     if plane_near >= plane_far:
         raise gr.Error("NEAR plane must have a value smaller than the FAR plane")
 
-    def _process_3d(size_longest_px, filter_size, vertex_colors, scene_lights, output_model_scale=None):
+    def _process_3d(
+        size_longest_px,
+        filter_size,
+        vertex_colors,
+        scene_lights,
+        output_model_scale=None,
+        prepare_for_3d_printing=False,
+    ):
         image_rgb = input_image
         image_depth = files[0]
 
@@ -105,14 +134,18 @@ def process_3d(
         image_rgb_content.resize((image_new_w, image_new_h), Image.LANCZOS).save(
             image_rgb_new
         )
-        Image.open(image_depth).resize((image_new_w, image_new_h), Image.LANCZOS).save(
+        Image.open(image_depth).resize((image_new_w, image_new_h), Image.BILINEAR).save(
            image_depth_new
        )
 
         path_glb, path_stl = extrude_depth_3d(
             image_rgb_new,
             image_depth_new,
-            output_model_scale=size_longest_cm * 10 if output_model_scale is None else output_model_scale,
+            output_model_scale=(
+                size_longest_cm * 10
+                if output_model_scale is None
+                else output_model_scale
+            ),
             filter_size=filter_size,
             coef_near=plane_near,
             coef_far=plane_far,
@@ -122,24 +155,27 @@ def process_3d(
             f_back=frame_far / 100,
             vertex_colors=vertex_colors,
             scene_lights=scene_lights,
+            prepare_for_3d_printing=prepare_for_3d_printing,
         )
 
         return path_glb, path_stl
 
-    path_viewer_glb, _ = _process_3d(256, filter_size, vertex_colors=False, scene_lights=True, output_model_scale=1)
-    path_files_glb, path_files_stl = _process_3d(size_longest_px, filter_size, vertex_colors=True, scene_lights=False)
-
-    # sanitize 3d viewer glb path to keep babylon.js happy
-    path_viewer_glb_sanitized = os.path.join(os.path.dirname(path_viewer_glb), "preview.glb")
-    if path_viewer_glb_sanitized != path_viewer_glb:
-        os.rename(path_viewer_glb, path_viewer_glb_sanitized)
-    path_viewer_glb = path_viewer_glb_sanitized
+    path_viewer_glb, _ = _process_3d(
+        256, filter_size, vertex_colors=False, scene_lights=True, output_model_scale=1
+    )
+    path_files_glb, path_files_stl = _process_3d(
+        size_longest_px,
+        filter_size,
+        vertex_colors=True,
+        scene_lights=False,
+        prepare_for_3d_printing=True,
+    )
 
     return path_viewer_glb, [path_files_glb, path_files_stl]
 
 
 def run_demo_server(pipe):
-    process_pipe = functools.partial(process, pipe)
+    process_pipe = spaces.GPU(functools.partial(process, pipe), duration=120)
     os.environ["GRADIO_ALLOW_FLAGGING"] = "never"
 
     with gr.Blocks(
@@ -156,11 +192,24 @@ def run_demo_server(pipe):
             .viewport {
                 aspect-ratio: 4/3;
             }
+            h1 {
+                text-align: center;
+                display: block;
+            }
+            h2 {
+                text-align: center;
+                display: block;
+            }
+            h3 {
+                text-align: center;
+                display: block;
+            }
         """,
     ) as demo:
         gr.Markdown(
             """
-            <h1 align="center">Marigold Depth Estimation</h1>
+            # Marigold Depth Estimation
+
            <p align="center">
                <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
@@ -175,12 +224,15 @@ def run_demo_server(pipe):
                    <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
                </a>
            </p>
-            <p align="justify">
-                Marigold is the new state-of-the-art depth estimator for images in the wild.
-                Upload your image into the <b>left</b> side, or click any of the <b>examples</b> below.
-                The result will be computed and appear on the <b>right</b> in the output comparison window.
-                <b style="color: red;">NEW</b>: Scroll down to the new 3D printing part of the demo!
-            </p>
+
+            Marigold is the state-of-the-art depth estimator for images in the wild.
+            Upload your image into the <b>first</b> pane, or click any of the <b>examples</b> below.
+            The result will be computed and appear in the <b>second</b> pane.
+            Scroll down to use the computed depth map for creating a 3D printable asset.
+
+            <a href="https://huggingface.co/spaces/prs-eth/marigold-lcm" style="color: crimson;">
+                <h3 style="color: crimson;">Check out Marigold-LCM — a FAST version of this demo!<h3>
+            </a>
             """
         )
 
@@ -200,7 +252,7 @@ def run_demo_server(pipe):
                    )
                    denoise_steps = gr.Slider(
                        label="Number of denoising steps",
-                        minimum=1,
+                        minimum=10,
                        maximum=20,
                        step=1,
                        value=10,
@@ -356,8 +408,17 @@ def run_demo_server(pipe):
        )
 
        blocks_settings_depth = [ensemble_size, denoise_steps, processing_res]
-        blocks_settings_3d = [plane_near, plane_far, embossing, size_longest_px, size_longest_cm, filter_size,
-                              frame_thickness, frame_near, frame_far]
+        blocks_settings_3d = [
+            plane_near,
+            plane_far,
+            embossing,
+            size_longest_px,
+            size_longest_cm,
+            filter_size,
+            frame_thickness,
+            frame_near,
+            frame_far,
+        ]
        blocks_settings = blocks_settings_depth + blocks_settings_3d
        map_id_to_default = {b._id: b.value for b in blocks_settings}
 
@@ -470,14 +531,21 @@ def run_demo_server(pipe):
                gr.Button(interactive=True),
                gr.Button(interactive=True),
                gr.Image(value=None, interactive=True),
-                None, None, None, None, None, None, None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
+                None,
            ]
            return out
 
        clear_btn.click(
            fn=clear_fn,
            inputs=[],
-            outputs=blocks_settings + [
+            outputs=blocks_settings
+            + [
                submit_btn,
                submit_3d,
                input_image,
@@ -532,37 +600,23 @@ def run_demo_server(pipe):
    )
 
 
-def prefetch_hf_cache(pipe):
-    process(pipe, "files/bee.jpg", 1, 1, 64)
-    shutil.rmtree("files/bee_output")
-
-
 def main():
-    REPO_URL = "https://github.com/prs-eth/Marigold.git"
-    REPO_HASH = "02cdfa52"
-    REPO_DIR = "Marigold"
-    CHECKPOINT = "Bingxin/Marigold"
+    CHECKPOINT = "prs-eth/marigold-v1-0"
 
-    if os.path.isdir(REPO_DIR):
-        shutil.rmtree(REPO_DIR)
-    repo = git.Repo.clone_from(REPO_URL, REPO_DIR)
-    repo.git.checkout(REPO_HASH)
-
-    sys.path.append(os.path.join(os.getcwd(), REPO_DIR))
-
-    from marigold import MarigoldPipeline
+    if "HF_TOKEN_LOGIN" in os.environ:
+        login(token=os.environ["HF_TOKEN_LOGIN"])
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     pipe = MarigoldPipeline.from_pretrained(CHECKPOINT)
     try:
         import xformers
+
         pipe.enable_xformers_memory_efficient_attention()
     except:
         pass  # run without xformers
 
     pipe = pipe.to(device)
-    prefetch_hf_cache(pipe)
     run_demo_server(pipe)
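The ZeroGPU support above comes down to wrapping the heavy inference callable with `spaces.GPU`, so a GPU is attached to the Space only for the duration of each call. A minimal sketch of the same pattern, outside the diff and with illustrative `predict`/`make_gpu_predictor` names, might look like this:

```python
import functools

import spaces
import torch


def predict(pipe, image):
    # Runs on the GPU that ZeroGPU attaches for the duration of this call.
    with torch.no_grad():
        return pipe(image)


def make_gpu_predictor(pipe):
    # Same wrapping as in app.py: bind the heavy pipeline object first, then
    # request a GPU slot of at most `duration` seconds per invocation.
    return spaces.GPU(functools.partial(predict, pipe), duration=120)
```

Outside a ZeroGPU Space the wrapper degrades to a plain call, which is why the rest of the Gradio wiring in app.py can stay unchanged.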
 
extrude.py CHANGED
@@ -1,3 +1,23 @@
+# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+# If you find this code useful, we kindly ask you to cite our paper in your work.
+# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
+# More information about the method can be found at https://marigoldmonodepth.github.io
+# --------------------------------------------------------------------------
+
+
 import math
 import os
 
@@ -53,7 +73,12 @@ def glb_add_lights(path_input, path_output):
         angle = i * angle_step
 
         pos_rot = [0.0, 0.0, math.sin(angle / 2), math.cos(angle / 2)]
-        elev_rot = [math.sin(elevation_angle / 2), 0.0, 0.0, math.cos(elevation_angle / 2)]
+        elev_rot = [
+            math.sin(elevation_angle / 2),
+            0.0,
+            0.0,
+            math.cos(elevation_angle / 2),
+        ]
         rotation = quaternion_multiply(pos_rot, elev_rot)
 
         node = {
@@ -88,6 +113,7 @@ def extrude_depth_3d(
     f_back=0.01,
     vertex_colors=True,
     scene_lights=True,
+    prepare_for_3d_printing=False,
 ):
     f_far_inner = -emboss
     f_far_outer = f_far_inner - f_back
@@ -309,6 +335,12 @@ def extrude_depth_3d(
     scaling_factor = output_model_scale / current_max_dimension
     mesh.apply_scale(scaling_factor)
 
+    if prepare_for_3d_printing:
+        rotation_mat = trimesh.transformations.rotation_matrix(
+            np.radians(90), [-1, 0, 0]
+        )
+        mesh.apply_transform(rotation_mat)
+
     path_out_base = os.path.splitext(path_depth)[0].replace("_16bit", "")
     path_out_glb = path_out_base + ".glb"
     path_out_stl = path_out_base + ".stl"
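The new `prepare_for_3d_printing` flag reorients the extruded relief by rotating it 90 degrees about the -X axis so it lies flat, which is the horizontal orientation mentioned in the commit notes. A self-contained sketch of that step, using a placeholder box mesh instead of the real extrusion, could be:

```python
import numpy as np
import trimesh


def lay_flat_for_printing(mesh: trimesh.Trimesh) -> trimesh.Trimesh:
    # Same transform as the new block in extrude_depth_3d: rotate 90 degrees
    # about -X so the relief rests on the print bed instead of standing upright.
    rotation_mat = trimesh.transformations.rotation_matrix(np.radians(90), [-1, 0, 0])
    mesh.apply_transform(rotation_mat)
    return mesh


if __name__ == "__main__":
    demo_mesh = trimesh.creation.box(extents=(10.0, 10.0, 2.0))  # stand-in relief
    lay_flat_for_printing(demo_mesh).export("relief_flat.stl")
```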
files/bee.jpg CHANGED

Git LFS Details

  • SHA256: 863fccd5ac347c831520ecbb1331e19bc5cfc3caf51acac8dd9a838262a612df
  • Pointer size: 130 Bytes
  • Size of remote file: 77.9 kB
files/cat.jpg CHANGED

Git LFS Details

  • SHA256: 7da86be40e88f33249ce3d7e31b8e725cdc7c8a7daaf45f2c9349860bb6e5deb
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
files/swings.jpg CHANGED

Git LFS Details

  • SHA256: cae2ac669c948313eae8aca53017f10b64b42f87c53b9c34639962b218fdf1f1
  • Pointer size: 131 Bytes
  • Size of remote file: 353 kB
marigold_depth_estimation.py ADDED
@@ -0,0 +1,632 @@
1
+ # Copyright 2024 Bingxin Ke, ETH Zurich and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import math
22
+ from typing import Dict, Union
23
+
24
+ import matplotlib
25
+ import numpy as np
26
+ import torch
27
+ from PIL import Image
28
+ from scipy.optimize import minimize
29
+ from torch.utils.data import DataLoader, TensorDataset
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from diffusers import (
34
+ AutoencoderKL,
35
+ DDIMScheduler,
36
+ DiffusionPipeline,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.utils import BaseOutput, check_min_version
40
+
41
+
42
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
43
+ check_min_version("0.27.0.dev0")
44
+
45
+
46
+ class MarigoldDepthOutput(BaseOutput):
47
+ """
48
+ Output class for Marigold monocular depth prediction pipeline.
49
+
50
+ Args:
51
+ depth_np (`np.ndarray`):
52
+ Predicted depth map, with depth values in the range of [0, 1].
53
+ depth_colored (`None` or `PIL.Image.Image`):
54
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
55
+ uncertainty (`None` or `np.ndarray`):
56
+ Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
57
+ """
58
+
59
+ depth_np: np.ndarray
60
+ depth_colored: Union[None, Image.Image]
61
+ uncertainty: Union[None, np.ndarray]
62
+
63
+
64
+ class MarigoldPipeline(DiffusionPipeline):
65
+ """
66
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
67
+
68
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
69
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
70
+
71
+ Args:
72
+ unet (`UNet2DConditionModel`):
73
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
74
+ vae (`AutoencoderKL`):
75
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
76
+ to and from latent representations.
77
+ scheduler (`DDIMScheduler`):
78
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
79
+ text_encoder (`CLIPTextModel`):
80
+ Text-encoder, for empty text embedding.
81
+ tokenizer (`CLIPTokenizer`):
82
+ CLIP tokenizer.
83
+ """
84
+
85
+ rgb_latent_scale_factor = 0.18215
86
+ depth_latent_scale_factor = 0.18215
87
+
88
+ def __init__(
89
+ self,
90
+ unet: UNet2DConditionModel,
91
+ vae: AutoencoderKL,
92
+ scheduler: DDIMScheduler,
93
+ text_encoder: CLIPTextModel,
94
+ tokenizer: CLIPTokenizer,
95
+ ):
96
+ super().__init__()
97
+
98
+ self.register_modules(
99
+ unet=unet,
100
+ vae=vae,
101
+ scheduler=scheduler,
102
+ text_encoder=text_encoder,
103
+ tokenizer=tokenizer,
104
+ )
105
+
106
+ self.empty_text_embed = None
107
+
108
+ @torch.no_grad()
109
+ def __call__(
110
+ self,
111
+ input_image: Image,
112
+ denoising_steps: int = 10,
113
+ ensemble_size: int = 10,
114
+ processing_res: int = 768,
115
+ match_input_res: bool = True,
116
+ batch_size: int = 0,
117
+ color_map: str = "Spectral",
118
+ show_progress_bar: bool = True,
119
+ ensemble_kwargs: Dict = None,
120
+ ) -> MarigoldDepthOutput:
121
+ """
122
+ Function invoked when calling the pipeline.
123
+
124
+ Args:
125
+ input_image (`Image`):
126
+ Input RGB (or gray-scale) image.
127
+ processing_res (`int`, *optional*, defaults to `768`):
128
+ Maximum resolution of processing.
129
+ If set to 0: will not resize at all.
130
+ match_input_res (`bool`, *optional*, defaults to `True`):
131
+ Resize depth prediction to match input resolution.
132
+ Only valid if `limit_input_res` is not None.
133
+ denoising_steps (`int`, *optional*, defaults to `10`):
134
+ Number of diffusion denoising steps (DDIM) during inference.
135
+ ensemble_size (`int`, *optional*, defaults to `10`):
136
+ Number of predictions to be ensembled.
137
+ batch_size (`int`, *optional*, defaults to `0`):
138
+ Inference batch size, no bigger than `num_ensemble`.
139
+ If set to 0, the script will automatically decide the proper batch size.
140
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
141
+ Display a progress bar of diffusion denoising.
142
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
143
+ Colormap used to colorize the depth map.
144
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
145
+ Arguments for detailed ensembling settings.
146
+ Returns:
147
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
148
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
149
+ - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
150
+ values in [0, 1]. None if `color_map` is `None`
151
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
152
+ coming from ensembling. None if `ensemble_size = 1`
153
+ """
154
+
155
+ device = self.device
156
+ input_size = input_image.size
157
+
158
+ if not match_input_res:
159
+ assert (
160
+ processing_res is not None
161
+ ), "Value error: `resize_output_back` is only valid with "
162
+ assert processing_res >= 0
163
+ assert denoising_steps >= 1
164
+ assert ensemble_size >= 1
165
+
166
+ # ----------------- Image Preprocess -----------------
167
+ # Resize image
168
+ if processing_res > 0:
169
+ input_image = self.resize_max_res(
170
+ input_image, max_edge_resolution=processing_res
171
+ )
172
+ # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
173
+ input_image = input_image.convert("RGB")
174
+ image = np.asarray(input_image)
175
+
176
+ # Normalize rgb values
177
+ rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
178
+ rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
179
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
180
+ rgb_norm = rgb_norm.to(device)
181
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
182
+
183
+ # ----------------- Predicting depth -----------------
184
+ # Batch repeated input image
185
+ duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
186
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
187
+ if batch_size > 0:
188
+ _bs = batch_size
189
+ else:
190
+ _bs = self._find_batch_size(
191
+ ensemble_size=ensemble_size,
192
+ input_res=max(rgb_norm.shape[1:]),
193
+ dtype=self.dtype,
194
+ )
195
+
196
+ single_rgb_loader = DataLoader(
197
+ single_rgb_dataset, batch_size=_bs, shuffle=False
198
+ )
199
+
200
+ # Predict depth maps (batched)
201
+ depth_pred_ls = []
202
+ if show_progress_bar:
203
+ iterable = tqdm(
204
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
205
+ )
206
+ else:
207
+ iterable = single_rgb_loader
208
+ for batch in iterable:
209
+ (batched_img,) = batch
210
+ depth_pred_raw = self.single_infer(
211
+ rgb_in=batched_img,
212
+ num_inference_steps=denoising_steps,
213
+ show_pbar=show_progress_bar,
214
+ )
215
+ depth_pred_ls.append(depth_pred_raw.detach().clone())
216
+ depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
217
+ torch.cuda.empty_cache() # clear vram cache for ensembling
218
+
219
+ # ----------------- Test-time ensembling -----------------
220
+ if ensemble_size > 1:
221
+ depth_pred, pred_uncert = self.ensemble_depths(
222
+ depth_preds, **(ensemble_kwargs or {})
223
+ )
224
+ else:
225
+ depth_pred = depth_preds
226
+ pred_uncert = None
227
+
228
+ # ----------------- Post processing -----------------
229
+ # Scale prediction to [0, 1]
230
+ min_d = torch.min(depth_pred)
231
+ max_d = torch.max(depth_pred)
232
+ depth_pred = (depth_pred - min_d) / (max_d - min_d)
233
+
234
+ # Convert to numpy
235
+ depth_pred = depth_pred.cpu().numpy().astype(np.float32)
236
+
237
+ # Resize back to original resolution
238
+ if match_input_res:
239
+ pred_img = Image.fromarray(depth_pred)
240
+ pred_img = pred_img.resize(input_size)
241
+ depth_pred = np.asarray(pred_img)
242
+
243
+ # Clip output range
244
+ depth_pred = depth_pred.clip(0, 1)
245
+
246
+ # Colorize
247
+ if color_map is not None:
248
+ depth_colored = self.colorize_depth_maps(
249
+ depth_pred, 0, 1, cmap=color_map
250
+ ).squeeze() # [3, H, W], value in (0, 1)
251
+ depth_colored = (depth_colored * 255).astype(np.uint8)
252
+ depth_colored_hwc = self.chw2hwc(depth_colored)
253
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
254
+ else:
255
+ depth_colored_img = None
256
+ return MarigoldDepthOutput(
257
+ depth_np=depth_pred,
258
+ depth_colored=depth_colored_img,
259
+ uncertainty=pred_uncert,
260
+ )
261
+
262
+ def _encode_empty_text(self):
263
+ """
264
+ Encode text embedding for empty prompt.
265
+ """
266
+ prompt = ""
267
+ text_inputs = self.tokenizer(
268
+ prompt,
269
+ padding="do_not_pad",
270
+ max_length=self.tokenizer.model_max_length,
271
+ truncation=True,
272
+ return_tensors="pt",
273
+ )
274
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
275
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
276
+
277
+ @torch.no_grad()
278
+ def single_infer(
279
+ self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool
280
+ ) -> torch.Tensor:
281
+ """
282
+ Perform an individual depth prediction without ensembling.
283
+
284
+ Args:
285
+ rgb_in (`torch.Tensor`):
286
+ Input RGB image.
287
+ num_inference_steps (`int`):
288
+ Number of diffusion denoising steps (DDIM) during inference.
289
+ show_pbar (`bool`):
290
+ Display a progress bar of diffusion denoising.
291
+ Returns:
292
+ `torch.Tensor`: Predicted depth map.
293
+ """
294
+ device = rgb_in.device
295
+
296
+ # Set timesteps
297
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
298
+ timesteps = self.scheduler.timesteps # [T]
299
+
300
+ # Encode image
301
+ rgb_latent = self._encode_rgb(rgb_in)
302
+
303
+ # Initial depth map (noise)
304
+ depth_latent = torch.randn(
305
+ rgb_latent.shape, device=device, dtype=self.dtype
306
+ ) # [B, 4, h, w]
307
+
308
+ # Batched empty text embedding
309
+ if self.empty_text_embed is None:
310
+ self._encode_empty_text()
311
+ batch_empty_text_embed = self.empty_text_embed.repeat(
312
+ (rgb_latent.shape[0], 1, 1)
313
+ ) # [B, 2, 1024]
314
+
315
+ # Denoising loop
316
+ if show_pbar:
317
+ iterable = tqdm(
318
+ enumerate(timesteps),
319
+ total=len(timesteps),
320
+ leave=False,
321
+ desc=" " * 4 + "Diffusion denoising",
322
+ )
323
+ else:
324
+ iterable = enumerate(timesteps)
325
+
326
+ for i, t in iterable:
327
+ unet_input = torch.cat(
328
+ [rgb_latent, depth_latent], dim=1
329
+ ) # this order is important
330
+
331
+ # predict the noise residual
332
+ noise_pred = self.unet(
333
+ unet_input, t, encoder_hidden_states=batch_empty_text_embed
334
+ ).sample # [B, 4, h, w]
335
+
336
+ # compute the previous noisy sample x_t -> x_t-1
337
+ depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
338
+ torch.cuda.empty_cache()
339
+ depth = self._decode_depth(depth_latent)
340
+
341
+ # clip prediction
342
+ depth = torch.clip(depth, -1.0, 1.0)
343
+ # shift to [0, 1]
344
+ depth = (depth + 1.0) / 2.0
345
+
346
+ return depth
347
+
348
+ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
349
+ """
350
+ Encode RGB image into latent.
351
+
352
+ Args:
353
+ rgb_in (`torch.Tensor`):
354
+ Input RGB image to be encoded.
355
+
356
+ Returns:
357
+ `torch.Tensor`: Image latent.
358
+ """
359
+ # encode
360
+ h = self.vae.encoder(rgb_in)
361
+ moments = self.vae.quant_conv(h)
362
+ mean, logvar = torch.chunk(moments, 2, dim=1)
363
+ # scale latent
364
+ rgb_latent = mean * self.rgb_latent_scale_factor
365
+ return rgb_latent
366
+
367
+ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
368
+ """
369
+ Decode depth latent into depth map.
370
+
371
+ Args:
372
+ depth_latent (`torch.Tensor`):
373
+ Depth latent to be decoded.
374
+
375
+ Returns:
376
+ `torch.Tensor`: Decoded depth map.
377
+ """
378
+ # scale latent
379
+ depth_latent = depth_latent / self.depth_latent_scale_factor
380
+ # decode
381
+ z = self.vae.post_quant_conv(depth_latent)
382
+ stacked = self.vae.decoder(z)
383
+ # mean of output channels
384
+ depth_mean = stacked.mean(dim=1, keepdim=True)
385
+ return depth_mean
386
+
387
+ @staticmethod
388
+ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
389
+ """
390
+ Resize image to limit maximum edge length while keeping aspect ratio.
391
+
392
+ Args:
393
+ img (`Image.Image`):
394
+ Image to be resized.
395
+ max_edge_resolution (`int`):
396
+ Maximum edge length (pixel).
397
+
398
+ Returns:
399
+ `Image.Image`: Resized image.
400
+ """
401
+ original_width, original_height = img.size
402
+ downscale_factor = min(
403
+ max_edge_resolution / original_width, max_edge_resolution / original_height
404
+ )
405
+
406
+ new_width = int(original_width * downscale_factor)
407
+ new_height = int(original_height * downscale_factor)
408
+
409
+ resized_img = img.resize((new_width, new_height))
410
+ return resized_img
411
+
412
+ @staticmethod
413
+ def colorize_depth_maps(
414
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
415
+ ):
416
+ """
417
+ Colorize depth maps.
418
+ """
419
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
420
+
421
+ if isinstance(depth_map, torch.Tensor):
422
+ depth = depth_map.detach().clone().squeeze().numpy()
423
+ elif isinstance(depth_map, np.ndarray):
424
+ depth = depth_map.copy().squeeze()
425
+ # reshape to [ (B,) H, W ]
426
+ if depth.ndim < 3:
427
+ depth = depth[np.newaxis, :, :]
428
+
429
+ # colorize
430
+ cm = matplotlib.colormaps[cmap]
431
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
432
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
433
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
434
+
435
+ if valid_mask is not None:
436
+ if isinstance(depth_map, torch.Tensor):
437
+ valid_mask = valid_mask.detach().numpy()
438
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
439
+ if valid_mask.ndim < 3:
440
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
441
+ else:
442
+ valid_mask = valid_mask[:, np.newaxis, :, :]
443
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
444
+ img_colored_np[~valid_mask] = 0
445
+
446
+ if isinstance(depth_map, torch.Tensor):
447
+ img_colored = torch.from_numpy(img_colored_np).float()
448
+ elif isinstance(depth_map, np.ndarray):
449
+ img_colored = img_colored_np
450
+
451
+ return img_colored
452
+
453
+ @staticmethod
454
+ def chw2hwc(chw):
455
+ assert 3 == len(chw.shape)
456
+ if isinstance(chw, torch.Tensor):
457
+ hwc = torch.permute(chw, (1, 2, 0))
458
+ elif isinstance(chw, np.ndarray):
459
+ hwc = np.moveaxis(chw, 0, -1)
460
+ return hwc
461
+
462
+ @staticmethod
463
+ def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
464
+ """
465
+ Automatically search for suitable operating batch size.
466
+
467
+ Args:
468
+ ensemble_size (`int`):
469
+ Number of predictions to be ensembled.
470
+ input_res (`int`):
471
+ Operating resolution of the input image.
472
+
473
+ Returns:
474
+ `int`: Operating batch size.
475
+ """
476
+ # Search table for suggested max. inference batch size
477
+ bs_search_table = [
478
+ # tested on A100-PCIE-80GB
479
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
480
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
481
+ # tested on A100-PCIE-40GB
482
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
483
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
484
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
485
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
486
+ # tested on RTX3090, RTX4090
487
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
488
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
489
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
490
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
491
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
492
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
493
+ # tested on GTX1080Ti
494
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
495
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
496
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
497
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
498
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
499
+ ]
500
+
501
+ if not torch.cuda.is_available():
502
+ return 1
503
+
504
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
505
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
506
+ for settings in sorted(
507
+ filtered_bs_search_table,
508
+ key=lambda k: (k["res"], -k["total_vram"]),
509
+ ):
510
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
511
+ bs = settings["bs"]
512
+ if bs > ensemble_size:
513
+ bs = ensemble_size
514
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
515
+ bs = math.ceil(ensemble_size / 2)
516
+ return bs
517
+
518
+ return 1
519
+
520
+ @staticmethod
521
+ def ensemble_depths(
522
+ input_images: torch.Tensor,
523
+ regularizer_strength: float = 0.02,
524
+ max_iter: int = 2,
525
+ tol: float = 1e-3,
526
+ reduction: str = "median",
527
+ max_res: int = None,
528
+ ):
529
+ """
530
+ To ensemble multiple affine-invariant depth images (up to scale and shift),
531
+ by jointly estimating the scale and shift that align them
532
+ """
533
+
534
+ def inter_distances(tensors: torch.Tensor):
535
+ """
536
+ To calculate the distance between each two depth maps.
537
+ """
538
+ distances = []
539
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
540
+ arr1 = tensors[i : i + 1]
541
+ arr2 = tensors[j : j + 1]
542
+ distances.append(arr1 - arr2)
543
+ dist = torch.concatenate(distances, dim=0)
544
+ return dist
545
+
546
+ device = input_images.device
547
+ dtype = input_images.dtype
548
+ np_dtype = np.float32
549
+
550
+ original_input = input_images.clone()
551
+ n_img = input_images.shape[0]
552
+ ori_shape = input_images.shape
553
+
554
+ if max_res is not None:
555
+ scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
556
+ if scale_factor < 1:
557
+ downscaler = torch.nn.Upsample(
558
+ scale_factor=scale_factor, mode="nearest"
559
+ )
560
+ input_images = downscaler(torch.from_numpy(input_images)).numpy()
561
+
562
+ # init guess
563
+ _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
564
+ _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
565
+ s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
566
+ t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
567
+ x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
568
+
569
+ input_images = input_images.to(device)
570
+
571
+ # objective function
572
+ def closure(x):
573
+ l = len(x)
574
+ s = x[: int(l / 2)]
575
+ t = x[int(l / 2) :]
576
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
577
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
578
+
579
+ transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
580
+ dists = inter_distances(transformed_arrays)
581
+ sqrt_dist = torch.sqrt(torch.mean(dists**2))
582
+
583
+ if "mean" == reduction:
584
+ pred = torch.mean(transformed_arrays, dim=0)
585
+ elif "median" == reduction:
586
+ pred = torch.median(transformed_arrays, dim=0).values
587
+ else:
588
+ raise ValueError
589
+
590
+ near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
591
+ far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
592
+
593
+ err = sqrt_dist + (near_err + far_err) * regularizer_strength
594
+ err = err.detach().cpu().numpy().astype(np_dtype)
595
+ return err
596
+
597
+ res = minimize(
598
+ closure,
599
+ x,
600
+ method="BFGS",
601
+ tol=tol,
602
+ options={"maxiter": max_iter, "disp": False},
603
+ )
604
+ x = res.x
605
+ l = len(x)
606
+ s = x[: int(l / 2)]
607
+ t = x[int(l / 2) :]
608
+
609
+ # Prediction
610
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
611
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
612
+ transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
613
+ if "mean" == reduction:
614
+ aligned_images = torch.mean(transformed_arrays, dim=0)
615
+ std = torch.std(transformed_arrays, dim=0)
616
+ uncertainty = std
617
+ elif "median" == reduction:
618
+ aligned_images = torch.median(transformed_arrays, dim=0).values
619
+ # MAD (median absolute deviation) as uncertainty indicator
620
+ abs_dev = torch.abs(transformed_arrays - aligned_images)
621
+ mad = torch.median(abs_dev, dim=0).values
622
+ uncertainty = mad
623
+ else:
624
+ raise ValueError(f"Unknown reduction method: {reduction}")
625
+
626
+ # Scale and shift to [0, 1]
627
+ _min = torch.min(aligned_images)
628
+ _max = torch.max(aligned_images)
629
+ aligned_images = (aligned_images - _min) / (_max - _min)
630
+ uncertainty /= _max - _min
631
+
632
+ return aligned_images, uncertainty
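For reference, the newly vendored pipeline can be driven outside Gradio the same way app.py does it. A minimal usage sketch follows (output file names are illustrative; the checkpoint and example image are the ones referenced by this commit):

```python
import torch
from PIL import Image

from marigold_depth_estimation import MarigoldPipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0").to(device)

out = pipe(
    Image.open("files/bee.jpg"),
    denoising_steps=10,   # matches the new slider minimum in app.py
    ensemble_size=10,
    processing_res=768,
)

out.depth_colored.save("bee_depth_colored.png")          # colorized PIL image
depth_16bit = (out.depth_np * 65535.0).astype("uint16")  # depth_np is in [0, 1]
Image.fromarray(depth_16bit).save("bee_depth_16bit.png")
```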
requirements.txt CHANGED
@@ -1,13 +1,13 @@
-gradio==4.11.0
+gradio==4.21.0
 gradio-imageslider==0.0.16
-GitPython==3.1.40
 pygltflib==1.16.1
 trimesh==4.0.5
 
+spaces>=0.25.0
 accelerate>=0.22.0
-diffusers>=0.20.1
+diffusers==0.27.2
 matplotlib==3.8.2
 scipy==1.11.4
 torch==2.0.1
 transformers>=4.32.1
-xformers==0.0.21
+xformers>=0.0.21
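Since the Space now pins `gradio`, `diffusers`, and the new `spaces` dependency, a quick local check against these pins can catch a mismatched environment before launch. An optional sketch (not part of the commit) using the standard library plus `packaging`:

```python
from importlib.metadata import version

from packaging.version import Version

assert version("gradio") == "4.21.0"
assert version("diffusers") == "0.27.2"
assert Version(version("spaces")) >= Version("0.25.0")
assert Version(version("xformers")) >= Version("0.0.21")
print("Environment matches the pins in requirements.txt")
```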