Spaces: Running on Zero
remove unused files
- marigold_depth_estimation_lcm.py +0 -710
- marigold_logo_square.jpg +0 -3
marigold_depth_estimation_lcm.py
DELETED
@@ -1,710 +0,0 @@
# Copyright 2024 Bingxin Ke, Anton Obukhov, ETH Zurich and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import math
from typing import Dict, Union, Tuple

import matplotlib
import numpy as np
import torch
from PIL import Image
from scipy.optimize import minimize
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")


class MarigoldDepthConsistencyOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`None` or `PIL.Image.Image`):
            Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
        depth_latent (`torch.Tensor`):
            Depth map's latent, with the shape of [4, h, w].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
    """

    depth_np: np.ndarray
    depth_colored: Union[None, Image.Image]
    depth_latent: torch.Tensor
    uncertainty: Union[None, np.ndarray]


class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
    """
    Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: DDIMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Image,
        denoising_steps: int = 1,
        ensemble_size: int = 1,
        processing_res: int = 768,
        match_input_res: bool = True,
        batch_size: int = 0,
        depth_latent_init: torch.Tensor = None,
        depth_latent_init_strength: float = 0.1,
        return_depth_latent: bool = False,
        seed: int = None,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
    ) -> MarigoldDepthConsistencyOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            processing_res (`int`, *optional*, defaults to `768`):
                Maximum resolution of processing.
                If set to 0: will not resize at all.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize depth prediction to match input resolution.
                Only valid if `limit_input_res` is not None.
            denoising_steps (`int`, *optional*, defaults to `1`):
                Number of diffusion denoising steps (consistency) during inference.
            ensemble_size (`int`, *optional*, defaults to `1`):
                Number of predictions to be ensembled.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `num_ensemble`.
                If set to 0, the script will automatically decide the proper batch size.
            depth_latent_init (`torch.Tensor`, *optional*, defaults to `None`):
                Initial depth map latent for better temporal consistency.
            depth_latent_init_strength (`float`, *optional*, defaults to `0.1`):
                Degree of initial depth latent influence, must be between 0 and 1.
            return_depth_latent (`bool`, defaults to `False`):
                Whether to return the depth latent.
            seed (`int`, *optional*, defaults to `None`):
                Reproducibility seed.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                Colormap used to colorize the depth map.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
        Returns:
            `MarigoldDepthConsistencyOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
              values in [0, 1]. None if `color_map` is `None`
            - **depth_latent** (`torch.Tensor`) Predicted depth map latent
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling. None if `ensemble_size = 1`
        """

        device = self.device
        input_size = input_image.size

        if not match_input_res:
            assert (
                processing_res is not None
            ), "Value error: `resize_output_back` is only valid with "
        assert processing_res >= 0, "Value error: `processing_res` must be non-negative"
        assert (
            1 <= denoising_steps <= 10
        ), "Value error: This model degrades with large number of steps"
        assert ensemble_size >= 1

        # ----------------- Image Preprocess -----------------
        # Resize image
        if processing_res > 0:
            input_image = self.resize_max_res(
                input_image, max_edge_resolution=processing_res
            )
        # Convert the image to RGB, to 1. remove the alpha channel 2. convert B&W to 3-channel
        input_image = input_image.convert("RGB")
        image = np.asarray(input_image)

        # Normalize rgb values
        rgb = np.transpose(image, (2, 0, 1))  # [H, W, rgb] -> [rgb, H, W]
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
        rgb_norm = rgb_norm.to(device)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth -----------------
        # Batch repeated input image
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        batch_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = self._find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(duplicated_rgb.shape[-2:]),
                dtype=self.dtype,
            )

        batch_loader = DataLoader(batch_dataset, batch_size=_bs, shuffle=False)

        # Predict depth maps (batched)
        depth_pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                batch_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = batch_loader
        depth_latent = None
        for batch in iterable:
            (batched_img,) = batch
            depth_pred_raw, depth_latent = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                depth_latent_init=depth_latent_init,
                depth_latent_init_strength=depth_latent_init_strength,
                seed=seed,
                show_pbar=show_progress_bar,
            )
            depth_pred_ls.append(depth_pred_raw.detach())
        depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:
            depth_pred, pred_uncert = self.ensemble_depths(
                depth_preds, **(ensemble_kwargs or {})
            )
        else:
            depth_pred = depth_preds
            pred_uncert = None

        # ----------------- Post processing -----------------
        # Scale prediction to [0, 1]
        min_d = torch.min(depth_pred)
        max_d = torch.max(depth_pred)
        depth_pred = (depth_pred - min_d) / (max_d - min_d)
        if return_depth_latent:
            if ensemble_size > 1:
                depth_latent = self._encode_depth(2 * depth_pred - 1)
        else:
            depth_latent = None

        # Convert to numpy
        depth_pred = depth_pred.cpu().numpy().astype(np.float32)

        # Resize back to original resolution
        if match_input_res:
            pred_img = Image.fromarray(depth_pred)
            pred_img = pred_img.resize(input_size)
            depth_pred = np.asarray(pred_img)

        # Clip output range
        depth_pred = depth_pred.clip(0, 1)

        # Colorize
        if color_map is not None:
            depth_colored = self.colorize_depth_maps(
                depth_pred, 0, 1, cmap=color_map
            ).squeeze()  # [3, H, W], value in (0, 1)
            depth_colored = (depth_colored * 255).astype(np.uint8)
            depth_colored_hwc = self.chw2hwc(depth_colored)
            depth_colored_img = Image.fromarray(depth_colored_hwc)
        else:
            depth_colored_img = None
        return MarigoldDepthConsistencyOutput(
            depth_np=depth_pred,
            depth_colored=depth_colored_img,
            depth_latent=depth_latent,
            uncertainty=pred_uncert,
        )

    def _encode_empty_text(self):
        """
        Encode text embedding for empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        depth_latent_init: torch.Tensor,
        depth_latent_init_strength: float,
        seed: int,
        show_pbar: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform an individual depth prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            depth_latent_init (`torch.Tensor`, `optional`):
                Initial depth latent.
            depth_latent_init_strength (`float`, `optional`):
                Degree of initial depth latent influence, must be between 0 and 1.
            seed (`int`, *optional*, defaults to `None`):
                Reproducibility seed.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: Predicted depth map and its latent.
        """
        device = rgb_in.device

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self._encode_rgb(rgb_in)

        # Initial depth map (noise)
        if seed is None:
            rng = None
        else:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        depth_latent = torch.randn(
            rgb_latent.shape, device=device, dtype=self.dtype, generator=rng
        )  # [B, 4, h, w]

        if depth_latent_init is not None:
            assert 0.0 <= depth_latent_init_strength <= 1.0
            assert (
                depth_latent_init.dim() == 4
                and depth_latent.dim() == 4
                and depth_latent_init.shape[0] == 1
            )
            if depth_latent.shape[0] != 1:
                depth_latent_init = depth_latent_init.repeat(
                    depth_latent.shape[0], 1, 1, 1
                )
            depth_latent *= 1.0 - depth_latent_init_strength
            depth_latent = depth_latent + depth_latent_init * depth_latent_init_strength

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self._encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, depth_latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            depth_latent = self.scheduler.step(
                noise_pred, t, depth_latent, generator=rng
            ).prev_sample

        depth = self._decode_depth(depth_latent)

        # clip prediction
        depth = torch.clip(depth, -1.0, 1.0)
        # shift to [0, 1]
        depth = (depth + 1.0) / 2.0

        return depth, depth_latent

    def _encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
        """
        Encode depth image into latent.

        Args:
            depth_in (`torch.Tensor`):
                Input Depth image to be encoded.

        Returns:
            `torch.Tensor`: Depth latent.
        """
        # encode
        dims = depth_in.squeeze().shape
        h = self.vae.encoder(depth_in.reshape(1, 1, *dims).repeat(1, 3, 1, 1))
        moments = self.vae.quant_conv(h)
        mean, _ = torch.chunk(moments, 2, dim=1)
        depth_latent = mean * self.depth_latent_scale_factor
        return depth_latent

    def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.rgb_latent_scale_factor
        return rgb_latent

    def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

    @staticmethod
    def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
        """
        Resize image to limit maximum edge length while keeping aspect ratio.

        Args:
            img (`Image.Image`):
                Image to be resized.
            max_edge_resolution (`int`):
                Maximum edge length (pixel).

        Returns:
            `Image.Image`: Resized image.
        """
        original_width, original_height = img.size
        downscale_factor = min(
            max_edge_resolution / original_width, max_edge_resolution / original_height
        )

        new_width = int(original_width * downscale_factor)
        new_height = int(original_height * downscale_factor)

        resized_img = img.resize((new_width, new_height))
        return resized_img

    @staticmethod
    def colorize_depth_maps(
        depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
    ):
        """
        Colorize depth maps.
        """
        assert len(depth_map.shape) >= 2, "Invalid dimension"

        if isinstance(depth_map, torch.Tensor):
            depth = depth_map.detach().squeeze().numpy()
        elif isinstance(depth_map, np.ndarray):
            depth = depth_map.copy().squeeze()
        # reshape to [ (B,) H, W ]
        if depth.ndim < 3:
            depth = depth[np.newaxis, :, :]

        # colorize
        cm = matplotlib.colormaps[cmap]
        depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
        img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
        img_colored_np = np.rollaxis(img_colored_np, 3, 1)

        if valid_mask is not None:
            if isinstance(depth_map, torch.Tensor):
                valid_mask = valid_mask.detach().numpy()
            valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
            if valid_mask.ndim < 3:
                valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
            else:
                valid_mask = valid_mask[:, np.newaxis, :, :]
            valid_mask = np.repeat(valid_mask, 3, axis=1)
            img_colored_np[~valid_mask] = 0

        if isinstance(depth_map, torch.Tensor):
            img_colored = torch.from_numpy(img_colored_np).float()
        elif isinstance(depth_map, np.ndarray):
            img_colored = img_colored_np

        return img_colored

    @staticmethod
    def chw2hwc(chw):
        assert 3 == len(chw.shape)
        if isinstance(chw, torch.Tensor):
            hwc = torch.permute(chw, (1, 2, 0))
        elif isinstance(chw, np.ndarray):
            hwc = np.moveaxis(chw, 0, -1)
        return hwc

    @staticmethod
    def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
        """
        Automatically search for suitable operating batch size.

        Args:
            ensemble_size (`int`):
                Number of predictions to be ensembled.
            input_res (`int`):
                Operating resolution of the input image.

        Returns:
            `int`: Operating batch size.
        """
        # Search table for suggested max. inference batch size
        bs_search_table = [
            # tested on A100-PCIE-80GB
            {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
            {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
            # tested on A100-PCIE-40GB
            {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
            {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
            {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
            {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
            # tested on RTX3090, RTX4090
            {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
            {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
            {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
            {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
            {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
            {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
            # tested on GTX1080Ti
            {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
            {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
            {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
            {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
            {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
        ]

        if not torch.cuda.is_available():
            return 1

        total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
        filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
        for settings in sorted(
            filtered_bs_search_table,
            key=lambda k: (k["res"], -k["total_vram"]),
        ):
            if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
                bs = settings["bs"]
                if bs > ensemble_size:
                    bs = ensemble_size
                elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
                    bs = math.ceil(ensemble_size / 2)
                return bs

        return 1

    @staticmethod
    def ensemble_depths(
        input_images: torch.Tensor,
        regularizer_strength: float = 0.02,
        max_iter: int = 2,
        tol: float = 1e-3,
        reduction: str = "median",
        max_res: int = None,
    ):
        """
        Ensemble multiple affine-invariant depth images (defined up to scale and shift)
        by estimating the per-image scale and shift that align them.
        """

        def inter_distances(tensors: torch.Tensor):
            """
            To calculate the distance between each two depth maps.
            """
            distances = []
            for i, j in torch.combinations(torch.arange(tensors.shape[0])):
                arr1 = tensors[i : i + 1]
                arr2 = tensors[j : j + 1]
                distances.append(arr1 - arr2)
            dist = torch.concatenate(distances, dim=0)
            return dist

        device = input_images.device
        dtype = input_images.dtype
        np_dtype = np.float32

        original_input = input_images.clone()
        n_img = input_images.shape[0]
        ori_shape = input_images.shape

        if max_res is not None:
            scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
            if scale_factor < 1:
                downscaler = torch.nn.Upsample(
                    scale_factor=scale_factor, mode="nearest"
                )
                input_images = downscaler(torch.from_numpy(input_images)).numpy()

        # init guess
        _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
        _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
        s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
        t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
        x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)

        input_images = input_images.to(device)

        # objective function
        def closure(x):
            l = len(x)
            s = x[: int(l / 2)]
            t = x[int(l / 2) :]
            s = torch.from_numpy(s).to(dtype=dtype).to(device)
            t = torch.from_numpy(t).to(dtype=dtype).to(device)

            transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
            dists = inter_distances(transformed_arrays)
            sqrt_dist = torch.sqrt(torch.mean(dists**2))

            if "mean" == reduction:
                pred = torch.mean(transformed_arrays, dim=0)
            elif "median" == reduction:
                pred = torch.median(transformed_arrays, dim=0).values
            else:
                raise ValueError

            near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
            far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

            err = sqrt_dist + (near_err + far_err) * regularizer_strength
            err = err.detach().cpu().numpy().astype(np_dtype)
            return err

        res = minimize(
            closure,
            x,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )
        x = res.x
        l = len(x)
        s = x[: int(l / 2)]
        t = x[int(l / 2) :]

        # Prediction
        s = torch.from_numpy(s).to(dtype=dtype).to(device)
        t = torch.from_numpy(t).to(dtype=dtype).to(device)
        transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
        if "mean" == reduction:
            aligned_images = torch.mean(transformed_arrays, dim=0)
            std = torch.std(transformed_arrays, dim=0)
            uncertainty = std
        elif "median" == reduction:
            aligned_images = torch.median(transformed_arrays, dim=0).values
            # MAD (median absolute deviation) as uncertainty indicator
            abs_dev = torch.abs(transformed_arrays - aligned_images)
            mad = torch.median(abs_dev, dim=0).values
            uncertainty = mad
        else:
            raise ValueError(f"Unknown reduction method: {reduction}")

        # Scale and shift to [0, 1]
        _min = torch.min(aligned_images)
        _max = torch.max(aligned_images)
        aligned_images = (aligned_images - _min) / (_max - _min)
        uncertainty /= _max - _min

        return aligned_images, uncertainty
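
For reference, the deleted class can still be used as a diffusers custom pipeline. Below is a minimal usage sketch, not part of this repository: it assumes the Marigold-LCM checkpoint id `prs-eth/marigold-lcm-v1-0` and that the module is available under the community pipeline name `marigold_depth_estimation_lcm` (both names are assumptions, not taken from this commit). It chains the returned `depth_latent` of one frame into `depth_latent_init` of the next, which is the temporal-consistency mechanism the pipeline exposes.

# Hypothetical usage sketch; checkpoint id and custom_pipeline name are assumptions.
import torch
from PIL import Image
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "prs-eth/marigold-lcm-v1-0",                     # assumed Marigold-LCM weights
    custom_pipeline="marigold_depth_estimation_lcm",  # assumed community pipeline name
    torch_dtype=torch.float16,
).to("cuda")

frames = [Image.open(p) for p in ["frame_000.png", "frame_001.png"]]  # placeholder paths
prev_latent = None
for frame in frames:
    out = pipe(
        frame,
        denoising_steps=1,                 # single consistency step, per the default
        ensemble_size=1,
        processing_res=768,
        depth_latent_init=prev_latent,     # previous frame's latent (None for the first frame)
        depth_latent_init_strength=0.1,
        return_depth_latent=True,          # keep the latent for the next frame
        seed=2024,
        show_progress_bar=False,
    )
    prev_latent = out.depth_latent         # [1, 4, h, w] latent reused next iteration
    depth = out.depth_np                   # [H, W] depth in [0, 1]
    vis = out.depth_colored                # colorized PIL image, or None if color_map=None

With `ensemble_size=1` and `return_depth_latent=True`, the latent handed back is the one produced by `single_infer`, which satisfies the batch-size-1 shape check that `depth_latent_init` is asserted against on the following call.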
marigold_logo_square.jpg
DELETED
Git LFS Details