GonzaloMG commited on
Commit
aaf39d1
·
verified ·
1 Parent(s): fbad7a8

Upload 8 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/examples/mushrooms.png filter=lfs diff=lfs merge=lfs -text
Marigold/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Code is copied from https://github.com/prs-eth/Marigold. Modifications are indicated within the code.
Marigold/marigold/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput # noqa: F401
Marigold/marigold/marigold_pipeline.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+ # @GonzaloMartinGarcia
21
+ # This file is a modified version of the original Marigold pipeline file.
22
+ # Based on GeoWizard, we added the option to sample surface normals, marked with # add.
23
+
24
+ from typing import Dict, Union
25
+
26
+ import numpy as np
27
+ import torch
28
+ from diffusers import (
29
+ AutoencoderKL,
30
+ DDIMScheduler,
31
+ DiffusionPipeline,
32
+ LCMScheduler,
33
+ UNet2DConditionModel,
34
+ DDPMScheduler,
35
+ )
36
+ from diffusers.utils import BaseOutput
37
+ from PIL import Image
38
+ from torchvision.transforms.functional import resize, pil_to_tensor
39
+ from torchvision.transforms import InterpolationMode
40
+ from torch.utils.data import DataLoader, TensorDataset
41
+ from tqdm.auto import tqdm
42
+ from transformers import CLIPTextModel, CLIPTokenizer
43
+
44
+ from .util.batchsize import find_batch_size
45
+ from .util.ensemble import ensemble_depths
46
+ from .util.image_util import (
47
+ chw2hwc,
48
+ colorize_depth_maps,
49
+ get_tv_resample_method,
50
+ resize_max_res,
51
+ )
52
+
53
+ # add
54
+ import random
55
+
56
+
57
+ # add
58
+ # Surface Normals Ensamble from the GeoWizard github repository (https://github.com/fuxiao0719/GeoWizard)
59
+ def ensemble_normals(input_images:torch.Tensor):
60
+ normal_preds = input_images
61
+ bsz, d, h, w = normal_preds.shape
62
+ normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1)+1e-5)
63
+ phi = torch.atan2(normal_preds[:,1,:,:], normal_preds[:,0,:,:]).mean(dim=0)
64
+ theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0)
65
+ normal_pred = torch.zeros((d,h,w)).to(normal_preds)
66
+ normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi)
67
+ normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi)
68
+ normal_pred[2,:,:] = torch.cos(theta)
69
+ angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1),-0.999, 0.999))
70
+ normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1))
71
+ return normal_preds[normal_idx], None
72
+
73
+ # add
74
+ # Pyramid nosie from
75
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31
76
+ def pyramid_noise_like(x, discount=0.9):
77
+ b, c, w, h = x.shape
78
+ u = torch.nn.Upsample(size=(w, h), mode='bilinear')
79
+ noise = torch.randn_like(x)
80
+ for i in range(10):
81
+ r = random.random()*2+2
82
+ w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
83
+ noise += u(torch.randn(b, c, w, h).to(x)) * discount**i
84
+ if w==1 or h==1:
85
+ break
86
+ return noise / noise.std()
87
+
88
+ class MarigoldDepthOutput(BaseOutput):
89
+ """
90
+ Output class for Marigold monocular depth prediction pipeline.
91
+
92
+ Args:
93
+ depth_np (`np.ndarray`):
94
+ Predicted depth map, with depth values in the range of [0, 1].
95
+ depth_colored (`PIL.Image.Image`):
96
+ Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
97
+ uncertainty (`None` or `np.ndarray`):
98
+ Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
99
+ normal_np (`np.ndarray`):
100
+ Predicted normal map, with normal vectors in the range of [-1, 1].
101
+ normal_colored (`PIL.Image.Image`):
102
+ Colorized normal map
103
+ """
104
+
105
+ depth_np: np.ndarray
106
+ depth_colored: Union[None, Image.Image]
107
+ uncertainty: Union[None, np.ndarray]
108
+ # add
109
+ normal_np: np.ndarray
110
+ normal_colored: Union[None, Image.Image]
111
+
112
+
113
+ class MarigoldPipeline(DiffusionPipeline):
114
+ """
115
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
116
+
117
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
118
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
119
+
120
+ Args:
121
+ unet (`UNet2DConditionModel`):
122
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
123
+ vae (`AutoencoderKL`):
124
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
125
+ to and from latent representations.
126
+ scheduler (`DDIMScheduler`):
127
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
128
+ text_encoder (`CLIPTextModel`):
129
+ Text-encoder, for empty text embedding.
130
+ tokenizer (`CLIPTokenizer`):
131
+ CLIP tokenizer.
132
+ """
133
+
134
+ rgb_latent_scale_factor = 0.18215
135
+ depth_latent_scale_factor = 0.18215
136
+
137
+ def __init__(
138
+ self,
139
+ unet: UNet2DConditionModel,
140
+ vae: AutoencoderKL,
141
+ scheduler: Union[DDIMScheduler,DDPMScheduler,LCMScheduler],
142
+ text_encoder: CLIPTextModel,
143
+ tokenizer: CLIPTokenizer,
144
+ ):
145
+ super().__init__()
146
+
147
+ self.register_modules(
148
+ unet=unet,
149
+ vae=vae,
150
+ scheduler=scheduler,
151
+ text_encoder=text_encoder,
152
+ tokenizer=tokenizer,
153
+ )
154
+
155
+ self.empty_text_embed = None
156
+
157
+ @torch.no_grad()
158
+ def __call__(
159
+ self,
160
+ input_image: Union[Image.Image, torch.Tensor],
161
+ denoising_steps: int = 10,
162
+ ensemble_size: int = 10,
163
+ processing_res: int = 768,
164
+ match_input_res: bool = True,
165
+ resample_method: str = "bilinear",
166
+ batch_size: int = 0,
167
+ color_map: str = "Spectral",
168
+ show_progress_bar: bool = True,
169
+ ensemble_kwargs: Dict = None,
170
+ # add
171
+ noise="gaussian",
172
+ normals=False,
173
+ ) -> MarigoldDepthOutput:
174
+ """
175
+ Function invoked when calling the pipeline.
176
+
177
+ Args:
178
+ input_image (`Image`):
179
+ Input RGB (or gray-scale) image.
180
+ processing_res (`int`, *optional*, defaults to `768`):
181
+ Maximum resolution of processing.
182
+ If set to 0: will not resize at all.
183
+ match_input_res (`bool`, *optional*, defaults to `True`):
184
+ Resize depth prediction to match input resolution.
185
+ Only valid if `processing_res` > 0.
186
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
187
+ Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
188
+ denoising_steps (`int`, *optional*, defaults to `10`):
189
+ Number of diffusion denoising steps (DDIM) during inference.
190
+ ensemble_size (`int`, *optional*, defaults to `10`):
191
+ Number of predictions to be ensembled.
192
+ batch_size (`int`, *optional*, defaults to `0`):
193
+ Inference batch size, no bigger than `num_ensemble`.
194
+ If set to 0, the script will automatically decide the proper batch size.
195
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
196
+ Display a progress bar of diffusion denoising.
197
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
198
+ Colormap used to colorize the depth map.
199
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
200
+ Arguments for detailed ensembling settings.
201
+ noise (`str`, *optional*, defaults to `gaussian`):
202
+ Type of noise to be used for the initial depth map.
203
+ Can be one of `gaussian`, `pyramid`, `zeros`.
204
+ normals (`bool`, *optional*, defaults to `False`):
205
+ If `True`, the pipeline will predict surface normals instead of depth maps.
206
+ Returns:
207
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
208
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
209
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
210
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
211
+ coming from ensembling. None if `ensemble_size = 1`
212
+ - **normal_np** (`np.ndarray`) Predicted normal map, with normal vectors in the range of [-1, 1]
213
+ - **normal_colored** (`PIL.Image.Image`) Colorized normal map
214
+ """
215
+
216
+ assert processing_res >= 0
217
+ assert ensemble_size >= 1
218
+
219
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
220
+
221
+ # ----------------- Image Preprocess -----------------
222
+
223
+ # Convert to torch tensor
224
+ if isinstance(input_image, Image.Image):
225
+ input_image = input_image.convert("RGB")
226
+ rgb = pil_to_tensor(input_image) # [H, W, rgb] -> [rgb, H, W]
227
+ elif isinstance(input_image, torch.Tensor):
228
+ rgb = input_image.squeeze()
229
+ else:
230
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
231
+ input_size = rgb.shape
232
+ assert (
233
+ 3 == rgb.dim() and 3 == input_size[0]
234
+ ), f"Wrong input shape {input_size}, expected [rgb, H, W]"
235
+
236
+ # Resize image
237
+ if processing_res > 0:
238
+ rgb = resize_max_res(
239
+ rgb,
240
+ max_edge_resolution=processing_res,
241
+ resample_method=resample_method,
242
+ )
243
+
244
+ # Normalize rgb values
245
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
246
+ rgb_norm = rgb_norm.to(self.dtype)
247
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
248
+
249
+ # ----------------- Predicting depth/normal --------------
250
+
251
+ # Batch repeated input image
252
+ duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
253
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
254
+ if batch_size > 0:
255
+ _bs = batch_size
256
+ else:
257
+ _bs = find_batch_size(
258
+ ensemble_size=ensemble_size,
259
+ input_res=max(rgb_norm.shape[1:]),
260
+ dtype=self.dtype,
261
+ )
262
+
263
+ single_rgb_loader = DataLoader(
264
+ single_rgb_dataset, batch_size=_bs, shuffle=False
265
+ )
266
+
267
+ # load iterator
268
+ pred_ls = []
269
+ if show_progress_bar:
270
+ iterable = tqdm(
271
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
272
+ )
273
+ else:
274
+ iterable = single_rgb_loader
275
+
276
+ # inference (batched)
277
+ for batch in iterable:
278
+ (batched_img,) = batch
279
+ pred_raw = self.single_infer(
280
+ rgb_in=batched_img,
281
+ num_inference_steps=denoising_steps,
282
+ show_pbar=show_progress_bar,
283
+ # add
284
+ noise=noise,
285
+ normals=normals,
286
+ )
287
+ pred_ls.append(pred_raw.detach())
288
+ preds = torch.concat(pred_ls, dim=0).squeeze()
289
+ torch.cuda.empty_cache() # clear vram cache for ensembling
290
+
291
+ # ----------------- Test-time ensembling -----------------
292
+
293
+ if ensemble_size > 1: # add
294
+ pred, pred_uncert = ensemble_normals(preds) if normals else ensemble_depths(preds, **(ensemble_kwargs or {}))
295
+ else:
296
+ pred = preds
297
+ pred_uncert = None
298
+
299
+ # ----------------- Post processing -----------------
300
+
301
+ if normals:
302
+ # add
303
+ # Normalizae normal vectors to unit length
304
+ pred /= (torch.norm(pred, p=2, dim=0, keepdim=True)+1e-5)
305
+ else:
306
+ # Scale relative prediction to [0, 1]
307
+ min_d = torch.min(pred)
308
+ max_d = torch.max(pred)
309
+ if max_d == min_d:
310
+ pred = torch.zeros_like(pred)
311
+ else:
312
+ pred = (pred - min_d) / (max_d - min_d)
313
+
314
+ # Resize back to original resolution
315
+ if match_input_res:
316
+ pred = resize(
317
+ pred if normals else pred.unsqueeze(0),
318
+ (input_size[-2],input_size[-1]),
319
+ interpolation=resample_method,
320
+ antialias=True,
321
+ ).squeeze()
322
+
323
+ # Convert to numpy
324
+ pred = pred.cpu().numpy()
325
+
326
+ # Process prediction for visualization
327
+ if not normals:
328
+ # add
329
+ pred = pred.clip(0, 1)
330
+ if color_map is not None:
331
+ colored = colorize_depth_maps(
332
+ pred, 0, 1, cmap=color_map
333
+ ).squeeze() # [3, H, W], value in (0, 1)
334
+ colored = (colored * 255).astype(np.uint8)
335
+ colored_hwc = chw2hwc(colored)
336
+ colored_img = Image.fromarray(colored_hwc)
337
+ else:
338
+ colored_img = None
339
+ else:
340
+ pred = pred.clip(-1.0, 1.0)
341
+ colored = (((pred+1)/2) * 255).astype(np.uint8)
342
+ colored_hwc = chw2hwc(colored)
343
+ colored_img = Image.fromarray(colored_hwc)
344
+
345
+
346
+ return MarigoldDepthOutput(
347
+ depth_np = pred if not normals else None,
348
+ depth_colored = colored_img if not normals else None,
349
+ uncertainty = pred_uncert,
350
+ # add
351
+ normal_np = pred if normals else None,
352
+ normal_colored = colored_img if normals else None,
353
+ )
354
+
355
+
356
+ def encode_empty_text(self):
357
+ """
358
+ Encode text embedding for empty prompt
359
+ """
360
+ prompt = ""
361
+ text_inputs = self.tokenizer(
362
+ prompt,
363
+ padding="do_not_pad",
364
+ max_length=self.tokenizer.model_max_length,
365
+ truncation=True,
366
+ return_tensors="pt",
367
+ )
368
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
369
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
370
+
371
+ @torch.no_grad()
372
+ def single_infer(
373
+ self,
374
+ rgb_in: torch.Tensor,
375
+ num_inference_steps: int,
376
+ show_pbar: bool,
377
+ # add
378
+ noise="gaussian",
379
+ normals=False,
380
+ ) -> torch.Tensor:
381
+ """
382
+ Perform an individual depth prediction without ensembling.
383
+
384
+ Args:
385
+ rgb_in (`torch.Tensor`):
386
+ Input RGB image.
387
+ num_inference_steps (`int`):
388
+ Number of diffusion denoisign steps (DDIM) during inference.
389
+ show_pbar (`bool`):
390
+ Display a progress bar of diffusion denoising.
391
+ noise (`str`, *optional*, defaults to `gaussian`):
392
+ Type of noise to be used for the initial depth map.
393
+ Can be one of `gaussian`, `pyramid`, `zeros`.
394
+ Returns:
395
+ `torch.Tensor`: Predicted depth map.
396
+ """
397
+ device = self.device
398
+ rgb_in = rgb_in.to(device)
399
+
400
+ # Set timesteps
401
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
402
+ timesteps = self.scheduler.timesteps # [T]
403
+
404
+ # Encode image
405
+ rgb_latent = self.encode_rgb(rgb_in)
406
+
407
+ # add
408
+ # Initial prediction
409
+ latent_shape = rgb_latent.shape
410
+ if noise == "gaussian":
411
+ latent = torch.randn(
412
+ latent_shape,
413
+ device=device,
414
+ dtype=self.dtype,
415
+ )
416
+ elif noise == "pyramid":
417
+ latent = pyramid_noise_like(rgb_latent).to(device) # [B, 4, h, w]
418
+ elif noise == "zeros":
419
+ latent = torch.zeros(
420
+ latent_shape,
421
+ device=device,
422
+ dtype=self.dtype,
423
+ )
424
+ else:
425
+ raise ValueError(f"Unknown noise type: {noise}")
426
+
427
+ # Batched empty text embedding
428
+ if self.empty_text_embed is None:
429
+ self.encode_empty_text()
430
+ batch_empty_text_embed = self.empty_text_embed.repeat(
431
+ (rgb_latent.shape[0], 1, 1)
432
+ ) # [B, 2, 1024]
433
+
434
+ # Denoising loop
435
+ if show_pbar:
436
+ iterable = tqdm(
437
+ enumerate(timesteps),
438
+ total=len(timesteps),
439
+ leave=False,
440
+ desc=" " * 4 + "Diffusion denoising",
441
+ )
442
+ else:
443
+ iterable = enumerate(timesteps)
444
+
445
+ for i, t in iterable:
446
+
447
+ unet_input = torch.cat(
448
+ [rgb_latent, latent], dim=1
449
+ ) # this order is important
450
+
451
+ # predict the noise residual
452
+ noise_pred = self.unet(
453
+ unet_input, t, encoder_hidden_states=batch_empty_text_embed
454
+ ).sample # [B, 4, h, w]
455
+
456
+ # compute the previous noisy sample x_t -> x_t-1
457
+ scheduler_step = self.scheduler.step(
458
+ noise_pred, t, latent
459
+ )
460
+
461
+ latent = scheduler_step.prev_sample
462
+
463
+ if normals:
464
+ # add
465
+ # decode and normalize normal vectors
466
+ normal = self.decode_normal(latent)
467
+ normal /= (torch.norm(normal, p=2, dim=1, keepdim=True)+1e-5)
468
+ return normal
469
+ else:
470
+ # decode and normalize depth map
471
+ depth = self.decode_depth(latent)
472
+ depth = torch.clip(depth, -1.0, 1.0)
473
+ depth = (depth + 1.0) / 2.0
474
+ return depth
475
+
476
+
477
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
478
+ """
479
+ Encode RGB image into latent.
480
+
481
+ Args:
482
+ rgb_in (`torch.Tensor`):
483
+ Input RGB image to be encoded.
484
+
485
+ Returns:
486
+ `torch.Tensor`: Image latent.
487
+ """
488
+ # encode
489
+ h = self.vae.encoder(rgb_in)
490
+ moments = self.vae.quant_conv(h)
491
+ mean, logvar = torch.chunk(moments, 2, dim=1)
492
+ # scale latent
493
+ rgb_latent = mean * self.rgb_latent_scale_factor
494
+ return rgb_latent
495
+
496
+
497
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
498
+ """
499
+ Decode depth latent into depth map.
500
+
501
+ Args:
502
+ depth_latent (`torch.Tensor`):
503
+ Depth latent to be decoded.
504
+
505
+ Returns:
506
+ `torch.Tensor`: Decoded depth map.
507
+ """
508
+ # scale latent
509
+ depth_latent = depth_latent / self.depth_latent_scale_factor
510
+ # decode
511
+ z = self.vae.post_quant_conv(depth_latent)
512
+ stacked = self.vae.decoder(z)
513
+ # mean of output channels
514
+ depth_mean = stacked.mean(dim=1, keepdim=True)
515
+ return depth_mean
516
+
517
+ # add
518
+ def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
519
+ """
520
+ Decode normal latent into normal map.
521
+
522
+ Args:
523
+ normal_latent (`torch.Tensor`):
524
+ normal latent to be decoded.
525
+
526
+ Returns:
527
+ `torch.Tensor`: Decoded depth map.
528
+ """
529
+ # scale latent
530
+ normal_latent = normal_latent / self.depth_latent_scale_factor
531
+ # decode
532
+ z = self.vae.post_quant_conv(normal_latent)
533
+ normal = self.vae.decoder(z)
534
+ return normal
Marigold/marigold/util/__init__.py ADDED
File without changes
Marigold/marigold/util/batchsize.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import torch
22
+ import math
23
+
24
+
25
+ # Search table for suggested max. inference batch size
26
+ bs_search_table = [
27
+ # tested on A100-PCIE-80GB
28
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
29
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
30
+ # tested on A100-PCIE-40GB
31
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
32
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
33
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
34
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
35
+ # tested on RTX3090, RTX4090
36
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
37
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
38
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
39
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
40
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
41
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
42
+ # tested on GTX1080Ti
43
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
44
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
45
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
46
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
47
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
48
+ ]
49
+
50
+
51
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
52
+ """
53
+ Automatically search for suitable operating batch size.
54
+
55
+ Args:
56
+ ensemble_size (`int`):
57
+ Number of predictions to be ensembled.
58
+ input_res (`int`):
59
+ Operating resolution of the input image.
60
+
61
+ Returns:
62
+ `int`: Operating batch size.
63
+ """
64
+ if not torch.cuda.is_available():
65
+ return 1
66
+
67
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
68
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
69
+ for settings in sorted(
70
+ filtered_bs_search_table,
71
+ key=lambda k: (k["res"], -k["total_vram"]),
72
+ ):
73
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
74
+ bs = settings["bs"]
75
+ if bs > ensemble_size:
76
+ bs = ensemble_size
77
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
78
+ bs = math.ceil(ensemble_size / 2)
79
+ return bs
80
+
81
+ return 1
Marigold/marigold/util/ensemble.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from scipy.optimize import minimize
25
+
26
+
27
+ def inter_distances(tensors: torch.Tensor):
28
+ """
29
+ To calculate the distance between each two depth maps.
30
+ """
31
+ distances = []
32
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
33
+ arr1 = tensors[i : i + 1]
34
+ arr2 = tensors[j : j + 1]
35
+ distances.append(arr1 - arr2)
36
+ dist = torch.concatenate(distances, dim=0)
37
+ return dist
38
+
39
+
40
+ def ensemble_depths(
41
+ input_images: torch.Tensor,
42
+ regularizer_strength: float = 0.02,
43
+ max_iter: int = 2,
44
+ tol: float = 1e-3,
45
+ reduction: str = "median",
46
+ max_res: int = None,
47
+ ):
48
+ """
49
+ To ensemble multiple affine-invariant depth images (up to scale and shift),
50
+ by aligning estimating the scale and shift
51
+ """
52
+ device = input_images.device
53
+ dtype = input_images.dtype
54
+ np_dtype = np.float32
55
+
56
+ original_input = input_images.clone()
57
+ n_img = input_images.shape[0]
58
+ ori_shape = input_images.shape
59
+
60
+ if max_res is not None:
61
+ scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
62
+ if scale_factor < 1:
63
+ downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
64
+ input_images = downscaler(input_images)
65
+
66
+ # init guess
67
+ _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
68
+ _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
69
+ s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
70
+ t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
71
+ x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
72
+
73
+ input_images = input_images.to(device)
74
+
75
+ # objective function
76
+ def closure(x):
77
+ len_x = len(x)
78
+ s = x[: int(len_x / 2)]
79
+ t = x[int(len_x / 2) :]
80
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
81
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
82
+
83
+ transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
84
+ dists = inter_distances(transformed_arrays)
85
+ sqrt_dist = torch.sqrt(torch.mean(dists**2))
86
+
87
+ if "mean" == reduction:
88
+ pred = torch.mean(transformed_arrays, dim=0)
89
+ elif "median" == reduction:
90
+ pred = torch.median(transformed_arrays, dim=0).values
91
+ else:
92
+ raise ValueError
93
+
94
+ near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
95
+ far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
96
+
97
+ err = sqrt_dist + (near_err + far_err) * regularizer_strength
98
+ err = err.detach().cpu().numpy().astype(np_dtype)
99
+ return err
100
+
101
+ res = minimize(
102
+ closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
103
+ )
104
+ x = res.x
105
+ len_x = len(x)
106
+ s = x[: int(len_x / 2)]
107
+ t = x[int(len_x / 2) :]
108
+
109
+ # Prediction
110
+ s = torch.from_numpy(s).to(dtype=dtype).to(device)
111
+ t = torch.from_numpy(t).to(dtype=dtype).to(device)
112
+ transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
113
+ if "mean" == reduction:
114
+ aligned_images = torch.mean(transformed_arrays, dim=0)
115
+ std = torch.std(transformed_arrays, dim=0)
116
+ uncertainty = std
117
+ elif "median" == reduction:
118
+ aligned_images = torch.median(transformed_arrays, dim=0).values
119
+ # MAD (median absolute deviation) as uncertainty indicator
120
+ abs_dev = torch.abs(transformed_arrays - aligned_images)
121
+ mad = torch.median(abs_dev, dim=0).values
122
+ uncertainty = mad
123
+ else:
124
+ raise ValueError(f"Unknown reduction method: {reduction}")
125
+
126
+ # Scale and shift to [0, 1]
127
+ _min = torch.min(aligned_images)
128
+ _max = torch.max(aligned_images)
129
+ aligned_images = (aligned_images - _min) / (_max - _min)
130
+ uncertainty /= _max - _min
131
+
132
+ return aligned_images, uncertainty
Marigold/marigold/util/image_util.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-04-16
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+
22
+ import matplotlib
23
+ import numpy as np
24
+ import torch
25
+ from torchvision.transforms import InterpolationMode
26
+ from torchvision.transforms.functional import resize
27
+
28
+
29
+ def colorize_depth_maps(
30
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
31
+ ):
32
+ """
33
+ Colorize depth maps.
34
+ """
35
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
36
+
37
+ if isinstance(depth_map, torch.Tensor):
38
+ depth = depth_map.detach().squeeze().numpy()
39
+ elif isinstance(depth_map, np.ndarray):
40
+ depth = depth_map.copy().squeeze()
41
+ # reshape to [ (B,) H, W ]
42
+ if depth.ndim < 3:
43
+ depth = depth[np.newaxis, :, :]
44
+
45
+ # colorize
46
+ cm = matplotlib.colormaps[cmap]
47
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
48
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
49
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
50
+
51
+ if valid_mask is not None:
52
+ if isinstance(depth_map, torch.Tensor):
53
+ valid_mask = valid_mask.detach().numpy()
54
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
55
+ if valid_mask.ndim < 3:
56
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
57
+ else:
58
+ valid_mask = valid_mask[:, np.newaxis, :, :]
59
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
60
+ img_colored_np[~valid_mask] = 0
61
+
62
+ if isinstance(depth_map, torch.Tensor):
63
+ img_colored = torch.from_numpy(img_colored_np).float()
64
+ elif isinstance(depth_map, np.ndarray):
65
+ img_colored = img_colored_np
66
+
67
+ return img_colored
68
+
69
+
70
+ def chw2hwc(chw):
71
+ assert 3 == len(chw.shape)
72
+ if isinstance(chw, torch.Tensor):
73
+ hwc = torch.permute(chw, (1, 2, 0))
74
+ elif isinstance(chw, np.ndarray):
75
+ hwc = np.moveaxis(chw, 0, -1)
76
+ return hwc
77
+
78
+
79
+ def resize_max_res(
80
+ img: torch.Tensor,
81
+ max_edge_resolution: int,
82
+ resample_method: InterpolationMode = InterpolationMode.BILINEAR,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Resize image to limit maximum edge length while keeping aspect ratio.
86
+
87
+ Args:
88
+ img (`torch.Tensor`):
89
+ Image tensor to be resized.
90
+ max_edge_resolution (`int`):
91
+ Maximum edge length (pixel).
92
+ resample_method (`PIL.Image.Resampling`):
93
+ Resampling method used to resize images.
94
+
95
+ Returns:
96
+ `torch.Tensor`: Resized image.
97
+ """
98
+ assert 3 == img.dim()
99
+ _, original_height, original_width = img.shape
100
+ downscale_factor = min(
101
+ max_edge_resolution / original_width, max_edge_resolution / original_height
102
+ )
103
+
104
+ new_width = int(original_width * downscale_factor)
105
+ new_height = int(original_height * downscale_factor)
106
+
107
+ resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
108
+ return resized_img
109
+
110
+
111
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
112
+ resample_method_dict = {
113
+ "bilinear": InterpolationMode.BILINEAR,
114
+ "bicubic": InterpolationMode.BICUBIC,
115
+ "nearest": InterpolationMode.NEAREST_EXACT,
116
+ }
117
+ resample_method = resample_method_dict.get(method_str, None)
118
+ if resample_method is None:
119
+ raise ValueError(f"Unknown resampling method: {resample_method}")
120
+ else:
121
+ return resample_method
assets/examples/mushrooms.png ADDED

Git LFS Details

  • SHA256: 249fef2c7ff127f46c6057d994c776b2127f16c6416235bd8331adda0d628642
  • Pointer size: 132 Bytes
  • Size of remote file: 5.14 MB