developy committed · Commit 7f10cdc · verified · 1 Parent(s): a7638ff

Upload 21 files

apdepth/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput # noqa: F401
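
The package root re-exports the pipeline and its output class, so downstream code can import both directly from `apdepth`. A minimal usage sketch (not part of this commit; it assumes a Marigold-style checkpoint whose components match the pipeline's `__init__`, and the repo id below is illustrative):

import torch
from PIL import Image

from apdepth import MarigoldPipeline, MarigoldDepthOutput

# Illustrative checkpoint id; substitute whichever weights this repo ships with.
pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input.jpg")  # any RGB image
out: MarigoldDepthOutput = pipe(
    image,
    denoising_steps=10,   # checked by _check_inference_step
    ensemble_size=1,
    processing_res=768,
    color_map="Spectral",
)
print(out.depth_np.shape)        # (H, W) float array with values in [0, 1]
out.depth_colored.save("depth_vis.png")
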
apdepth/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (216 Bytes).
 
apdepth/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (229 Bytes).
 
apdepth/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (214 Bytes).
 
apdepth/__pycache__/marigold_pipeline.cpython-310.pyc ADDED
Binary file (15.4 kB).
 
apdepth/__pycache__/marigold_pipeline.cpython-312.pyc ADDED
Binary file (19.9 kB).
 
apdepth/__pycache__/marigold_pipeline.cpython-38.pyc ADDED
Binary file (15 kB).
 
apdepth/marigold_pipeline.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+
22
+ import logging
23
+ from typing import Dict, Optional, Union
24
+
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from diffusers import (
30
+ AutoencoderKL,
31
+ DDIMScheduler,
32
+ DiffusionPipeline,
33
+ LCMScheduler,
34
+ # UNet2DConditionModel,
35
+ )
36
+ from marigold.modules.unet_2d_condition import UNet2DConditionModel
37
+ from diffusers.utils import BaseOutput
38
+ from PIL import Image
39
+ from torch.utils.data import DataLoader, TensorDataset
40
+ from torchvision.transforms import InterpolationMode
41
+ from torchvision.transforms.functional import pil_to_tensor, resize
42
+ from tqdm.auto import tqdm
43
+ from transformers import CLIPTextModel, CLIPTokenizer
44
+
45
+ from .util.batchsize import find_batch_size
46
+ from .util.ensemble import ensemble_depth
47
+ from .util.image_util import (
48
+ chw2hwc,
49
+ colorize_depth_maps,
50
+ get_tv_resample_method,
51
+ resize_max_res,
52
+ )
53
+
54
+
55
+ class MarigoldDepthOutput(BaseOutput):
56
+ """
57
+ Output class for Marigold monocular depth prediction pipeline.
58
+
59
+ Args:
60
+ depth_np (`np.ndarray`):
61
+ Predicted depth map, with depth values in the range of [0, 1].
62
+ depth_colored (`PIL.Image.Image`):
63
+ Colorized depth map, returned as an RGB `PIL.Image.Image`.
64
+ uncertainty (`None` or `np.ndarray`):
65
+ Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
66
+ """
67
+
68
+ depth_np: np.ndarray
69
+ depth_colored: Union[None, Image.Image]
70
+ uncertainty: Union[None, np.ndarray]
71
+
72
+
73
+ class MarigoldPipeline(DiffusionPipeline):
74
+ """
75
+ Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
76
+
77
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
78
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
79
+
80
+ Args:
81
+ unet (`UNet2DConditionModel`):
82
+ Conditional U-Net to denoise the depth latent, conditioned on image latent.
83
+ vae (`AutoencoderKL`):
84
+ Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
85
+ to and from latent representations.
86
+ scheduler (`DDIMScheduler`):
87
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
88
+ text_encoder (`CLIPTextModel`):
89
+ Text-encoder, for empty text embedding.
90
+ tokenizer (`CLIPTokenizer`):
91
+ CLIP tokenizer.
92
+ scale_invariant (`bool`, *optional*):
93
+ A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
94
+ the model config. When used together with the `shift_invariant=True` flag, the model is also called
95
+ "affine-invariant". NB: overriding this value is not supported.
96
+ shift_invariant (`bool`, *optional*):
97
+ A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
98
+ the model config. When used together with the `scale_invariant=True` flag, the model is also called
99
+ "affine-invariant". NB: overriding this value is not supported.
100
+ default_denoising_steps (`int`, *optional*):
101
+ The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
102
+ quality with the given model. This value must be set in the model config. When the pipeline is called
103
+ without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
104
+ reasonable results with various model flavors compatible with the pipeline, such as those relying on very
105
+ short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
106
+ default_processing_resolution (`int`, *optional*):
107
+ The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
108
+ the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
109
+ default value is used. This is required to ensure reasonable results with various model flavors trained
110
+ with varying optimal processing resolution values.
111
+ """
112
+
113
+ rgb_latent_scale_factor = 0.18215
114
+ depth_latent_scale_factor = 0.18215
115
+
116
+ def __init__(
117
+ self,
118
+ unet: UNet2DConditionModel,
119
+ vae: AutoencoderKL,
120
+ scheduler: Union[DDIMScheduler, LCMScheduler],
121
+ text_encoder: CLIPTextModel,
122
+ tokenizer: CLIPTokenizer,
123
+ scale_invariant: Optional[bool] = True,
124
+ shift_invariant: Optional[bool] = True,
125
+ default_denoising_steps: Optional[int] = None,
126
+ default_processing_resolution: Optional[int] = None,
127
+ ):
128
+ super().__init__()
129
+ self.register_modules(
130
+ unet=unet,
131
+ vae=vae,
132
+ scheduler=scheduler,
133
+ text_encoder=text_encoder,
134
+ tokenizer=tokenizer,
135
+ )
136
+ self.register_to_config(
137
+ scale_invariant=scale_invariant,
138
+ shift_invariant=shift_invariant,
139
+ default_denoising_steps=default_denoising_steps,
140
+ default_processing_resolution=default_processing_resolution,
141
+ )
142
+
143
+ self.scale_invariant = scale_invariant
144
+ self.shift_invariant = shift_invariant
145
+ self.default_denoising_steps = default_denoising_steps
146
+ self.default_processing_resolution = default_processing_resolution
147
+
148
+ self.empty_text_embed = None
149
+
150
+ self._fft_masks = {}
151
+
152
+ @torch.no_grad()
153
+ def __call__(
154
+ self,
155
+ input_image: Union[Image.Image, torch.Tensor],
156
+ denoising_steps: Optional[int] = None,
157
+ ensemble_size: int = 5,
158
+ processing_res: Optional[int] = None,
159
+ match_input_res: bool = True,
160
+ resample_method: str = "bilinear",
161
+ batch_size: int = 0,
162
+ color_map: str = "Spectral",
163
+ show_progress_bar: bool = True,
164
+ ensemble_kwargs: Dict = None,
165
+ ) -> MarigoldDepthOutput:
166
+ """
167
+ Function invoked when calling the pipeline.
168
+
169
+ Args:
170
+ input_image (`Image`):
171
+ Input RGB (or gray-scale) image.
172
+ denoising_steps (`int`, *optional*, defaults to `None`):
173
+ Number of denoising diffusion steps during inference. The default value `None` results in automatic
174
+ selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4
175
+ for Marigold-LCM models.
176
+ ensemble_size (`int`, *optional*, defaults to `5`):
177
+ Number of predictions to be ensembled.
178
+ processing_res (`int`, *optional*, defaults to `None`):
179
+ Effective processing resolution. When set to `0`, processes at the original image resolution. This
180
+ produces crisper predictions, but may also lead to the overall loss of global context. The default
181
+ value `None` resolves to the optimal value from the model config.
182
+ match_input_res (`bool`, *optional*, defaults to `True`):
183
+ Resize depth prediction to match input resolution.
184
+ Only valid if `processing_res` > 0.
185
+ resample_method: (`str`, *optional*, defaults to `bilinear`):
186
+ Resampling method used to resize images and depth predictions. One of `bilinear`, `bicubic`, or `nearest`.
187
+ batch_size (`int`, *optional*, defaults to `0`):
188
+ Inference batch size, no bigger than `ensemble_size`.
189
+ If set to 0, the script will automatically decide the proper batch size.
190
+ generator (`torch.Generator`, *optional*, defaults to `None`)
191
+ Random generator for initial noise generation.
192
+ show_progress_bar (`bool`, *optional*, defaults to `True`):
193
+ Display a progress bar of diffusion denoising.
194
+ color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
195
+ Colormap used to colorize the depth map.
196
+ scale_invariant (`str`, *optional*, defaults to `True`):
197
+ Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction.
198
+ shift_invariant (`str`, *optional*, defaults to `True`):
199
+ Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, near plane will be fixed at 0m.
200
+ ensemble_kwargs (`dict`, *optional*, defaults to `None`):
201
+ Arguments for detailed ensembling settings.
202
+ Returns:
203
+ `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
204
+ - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
205
+ - **depth_colored** (`PIL.Image.Image`) Colorized depth map as an RGB image, None if `color_map` is `None`
206
+ - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
207
+ coming from ensembling. None if `ensemble_size = 1`
208
+ """
209
+ # Model-specific optimal default values leading to fast and reasonable results.
+ if denoising_steps is None:
+ denoising_steps = self.default_denoising_steps
210
+ if processing_res is None:
211
+ processing_res = self.default_processing_resolution
212
+
213
+ assert processing_res >= 0
214
+
215
+ # Check if denoising step is reasonable
216
+ self._check_inference_step(denoising_steps)
217
+
218
+ resample_method: InterpolationMode = get_tv_resample_method(resample_method)
219
+
220
+ # ----------------- Image Preprocess -----------------
221
+ # Convert to torch tensor
222
+ if isinstance(input_image, Image.Image):
223
+ input_image = input_image.convert("RGB")
224
+ # convert to torch tensor [H, W, rgb] -> [rgb, H, W]
225
+ rgb = pil_to_tensor(input_image)
226
+ rgb = rgb.unsqueeze(0) # [1, rgb, H, W]
227
+ elif isinstance(input_image, torch.Tensor):
228
+ rgb = input_image
229
+ else:
230
+ raise TypeError(f"Unknown input type: {type(input_image) = }")
231
+ input_size = rgb.shape
232
+ assert (
233
+ 4 == rgb.dim() and 3 == input_size[-3]
234
+ ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
235
+
236
+ # Resize image
237
+ if processing_res > 0:
238
+ rgb = resize_max_res(
239
+ rgb,
240
+ max_edge_resolution=processing_res,
241
+ resample_method=resample_method,
242
+ )
243
+
244
+ # Normalize rgb values
245
+ rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1]
246
+ rgb_norm = rgb_norm.to(self.dtype)
247
+ assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
248
+
249
+ # ----------------- Predicting depth -----------------
250
+ # Batch repeated input image
251
+ duplicated_rgb = rgb_norm.expand(1, -1, -1, -1)
252
+ single_rgb_dataset = TensorDataset(duplicated_rgb)
253
+ if batch_size > 0:
254
+ _bs = batch_size
255
+ else:
256
+ _bs = 1
257
+
258
+ single_rgb_loader = DataLoader(
259
+ single_rgb_dataset, batch_size=_bs, shuffle=False
260
+ )
261
+
262
+ # Predict depth maps (batched)
263
+ depth_pred_ls = []
264
+ if show_progress_bar:
265
+ iterable = tqdm(
266
+ single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
267
+ )
268
+ else:
269
+ iterable = single_rgb_loader
270
+ for batch in iterable:
271
+ (batched_img,) = batch # image is already normalized to [-1, 1] here
272
+ depth_pred_raw = self.single_infer(
273
+ rgb_in=batched_img,
274
+ )
275
+ depth_pred_ls.append(depth_pred_raw.detach())
276
+ depth_preds = torch.concat(depth_pred_ls, dim=0)
277
+ torch.cuda.empty_cache() # clear vram cache for ensembling
278
+
279
+ # ----------------- Test-time ensembling -----------------
280
+ if ensemble_size > 1:
281
+ depth_pred, pred_uncert = ensemble_depth(
282
+ depth_preds,
283
+ scale_invariant=self.scale_invariant,
284
+ shift_invariant=self.shift_invariant,
285
+ max_res=50,
286
+ **(ensemble_kwargs or {}),
287
+ )
288
+ else:
289
+ depth_pred = depth_preds
290
+ pred_uncert = None
291
+
292
+ # Resize back to original resolution
293
+ if match_input_res:
294
+ depth_pred = resize(
295
+ depth_pred,
296
+ input_size[-2:],
297
+ interpolation=resample_method,
298
+ antialias=True,
299
+ )
300
+
301
+ # Convert to numpy
302
+ depth_pred = depth_pred.squeeze()
303
+ depth_pred = depth_pred.cpu().numpy()
304
+ if pred_uncert is not None:
305
+ pred_uncert = pred_uncert.squeeze().cpu().numpy()
306
+
307
+ # Clip output range
308
+ depth_pred = depth_pred.clip(0, 1)
309
+
310
+ # Colorize
311
+ if color_map is not None:
312
+ depth_colored = colorize_depth_maps(
313
+ depth_pred, 0, 1, cmap=color_map
314
+ ).squeeze() # [3, H, W], value in (0, 1)
315
+ depth_colored = (depth_colored * 255).astype(np.uint8)
316
+ depth_colored_hwc = chw2hwc(depth_colored)
317
+ depth_colored_img = Image.fromarray(depth_colored_hwc)
318
+ else:
319
+ depth_colored_img = None
320
+
321
+ return MarigoldDepthOutput(
322
+ depth_np=depth_pred,
323
+ depth_colored=depth_colored_img,
324
+ uncertainty=pred_uncert,
325
+ )
326
+
327
+ def _check_inference_step(self, n_step: int) -> None:
328
+ """
329
+ Check if denoising step is reasonable
330
+ Args:
331
+ n_step (`int`): denoising steps
332
+ """
333
+ assert n_step >= 1
334
+
335
+ if isinstance(self.scheduler, DDIMScheduler):
336
+ if n_step < 10:
337
+ logging.warning(
338
+ f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
339
+ )
340
+ elif isinstance(self.scheduler, LCMScheduler):
341
+ if not 1 <= n_step <= 4:
342
+ logging.warning(
343
+ f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
344
+ )
345
+ else:
346
+ raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
347
+
348
+ def encode_empty_text(self):
349
+ """
350
+ Encode text embedding for empty prompt
351
+ """
352
+ prompt = ""
353
+ text_inputs = self.tokenizer(
354
+ prompt,
355
+ padding="do_not_pad",
356
+ max_length=self.tokenizer.model_max_length,
357
+ truncation=True,
358
+ return_tensors="pt",
359
+ )
360
+ text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
361
+ self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
362
+
363
+ @torch.no_grad()
364
+ def _get_highpass_mask(self, H, W, radius, device):
365
+ key = (H, W, radius, device)
366
+ if key not in self._fft_masks:
367
+ yy, xx = torch.meshgrid(torch.arange(H, device=device),
368
+ torch.arange(W, device=device),
369
+ indexing="ij")
370
+ yy = yy - H // 2
371
+ xx = xx - W // 2
372
+ mask_low = (xx**2 + yy**2 <= radius**2).float()
373
+ mask_high = 1 - mask_low
374
+ mask_high = mask_high[None, None, :, :]
375
+ self._fft_masks[key] = mask_high
376
+ return self._fft_masks[key]
377
+
378
+ @torch.no_grad()
379
+ def rgb_fft(self, x: torch.Tensor, highpass_radius: int = 30):
380
+ B, C, H, W = x.shape
381
+ device = x.device
382
+
383
+ f = torch.fft.fft2(x, norm="ortho")
384
+ fshift = torch.fft.fftshift(f)
385
+
386
+ mask_high = self._get_highpass_mask(H, W, highpass_radius, device)
387
+ fshift_high = fshift * mask_high
388
+
389
+ f_ishift = torch.fft.ifftshift(fshift_high)
390
+ img_high = torch.fft.ifft2(f_ishift, norm="ortho").real
391
+ return img_high
392
+
393
+ @torch.no_grad()
394
+ def single_infer(
395
+ self,
396
+ rgb_in: torch.Tensor,
397
+ ) -> torch.Tensor:
398
+ """
399
+ Perform an individual depth prediction without ensembling.
400
+
401
+ Args:
402
+ rgb_in (`torch.Tensor`):
403
+ Input RGB image.
404
+ num_inference_steps (`int`):
405
+ Number of diffusion denoising steps (DDIM) during inference.
406
+ show_pbar (`bool`):
407
+ Display a progress bar of diffusion denoising.
408
+ generator (`torch.Generator`)
409
+ Random generator for initial noise generation.
410
+ Returns:
411
+ `torch.Tensor`: Predicted depth map.
412
+ """
413
+ device = self.device
414
+ rgb_in = rgb_in.to(device)
415
+ rgb_fft = self.rgb_fft(rgb_in)
416
+
417
+ # Encode image
418
+ rgb_latent = self.encode_rgb(rgb_in)
419
+ # rgb_fft_latent = self.encode_rgb(rgb_fft)
420
+
421
+ # Batched empty text embedding
422
+ if self.empty_text_embed is None:
423
+ self.encode_empty_text()
424
+ batch_empty_text_embed = self.empty_text_embed.repeat(
425
+ (rgb_latent.shape[0], 1, 1)
426
+ ).to(device) # [B, 2, 1024]
427
+
428
+ # unet_input = torch.cat([rgb_latent, rgb_fft_latent],dim=1)
429
+
430
+ depth_latent = self.unet(
431
+ rgb_latent, 1, encoder_hidden_states=batch_empty_text_embed
432
+ ).sample # [B, 4, h, w]
433
+
434
+ depth = self.decode_depth(depth_latent)
435
+
436
+ # clip prediction
437
+ depth = torch.clip(depth, -1.0, 1.0)
438
+ # shift to [0, 1]
439
+ depth = (depth + 1.0) / 2.0
440
+
441
+ return depth
442
+
443
+ def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
444
+ """
445
+ Encode RGB image into latent.
446
+
447
+ Args:
448
+ rgb_in (`torch.Tensor`):
449
+ Input RGB image to be encoded.
450
+
451
+ Returns:
452
+ `torch.Tensor`: Image latent.
453
+ """
454
+ # encode
455
+ h = self.vae.encoder(rgb_in)
456
+ moments = self.vae.quant_conv(h)
457
+ mean, logvar = torch.chunk(moments, 2, dim=1)
458
+ # scale latent
459
+ rgb_latent = mean * self.rgb_latent_scale_factor
460
+ return rgb_latent
461
+
462
+ def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
463
+ """
464
+ Decode depth latent into depth map.
465
+
466
+ Args:
467
+ depth_latent (`torch.Tensor`):
468
+ Depth latent to be decoded.
469
+
470
+ Returns:
471
+ `torch.Tensor`: Decoded depth map.
472
+ """
473
+ # scale latent
474
+ depth_latent = depth_latent / self.depth_latent_scale_factor
475
+ # decode
476
+ z = self.vae.post_quant_conv(depth_latent)
477
+ stacked = self.vae.decoder(z)
478
+ # mean of output channels
479
+ depth_mean = stacked.mean(dim=1, keepdim=True)
480
+ return depth_mean
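
For reference, the `_get_highpass_mask` / `rgb_fft` pair above keeps only the frequencies outside a centered disk of radius `highpass_radius` in the shifted spectrum, caching one mask per (H, W, radius, device). A self-contained sketch of the same filtering, assuming only `torch` (names are illustrative):

import torch

def highpass(x: torch.Tensor, radius: int = 30) -> torch.Tensor:
    # Zero out the low frequencies inside a centered disk of the given radius.
    _, _, H, W = x.shape
    f = torch.fft.fftshift(torch.fft.fft2(x, norm="ortho"))
    yy, xx = torch.meshgrid(
        torch.arange(H, device=x.device) - H // 2,
        torch.arange(W, device=x.device) - W // 2,
        indexing="ij",
    )
    mask_high = ((xx**2 + yy**2) > radius**2).float()[None, None]
    return torch.fft.ifft2(torch.fft.ifftshift(f * mask_high), norm="ortho").real

x = torch.rand(1, 3, 64, 64) * 2.0 - 1.0  # image tensor normalized to [-1, 1], as in the pipeline
x_high = highpass(x, radius=8)
print(x_high.shape)  # torch.Size([1, 3, 64, 64])
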
apdepth/modules/__pycache__/unet_2d_blocks.cpython-310.pyc ADDED
Binary file (67.3 kB).
 
apdepth/modules/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (40.8 kB).
 
apdepth/modules/unet_2d_blocks.py ADDED
The diff for this file is too large to render.
 
apdepth/modules/unet_2d_condition.py ADDED
@@ -0,0 +1,1314 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
24
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import (
27
+ ADDED_KV_ATTENTION_PROCESSORS,
28
+ CROSS_ATTENTION_PROCESSORS,
29
+ Attention,
30
+ AttentionProcessor,
31
+ AttnAddedKVProcessor,
32
+ AttnProcessor,
33
+ FusedAttnProcessor2_0,
34
+ )
35
+ from diffusers.models.embeddings import (
36
+ GaussianFourierProjection,
37
+ GLIGENTextBoundingboxProjection,
38
+ ImageHintTimeEmbedding,
39
+ ImageProjection,
40
+ ImageTimeEmbedding,
41
+ TextImageProjection,
42
+ TextImageTimeEmbedding,
43
+ TextTimeEmbedding,
44
+ TimestepEmbedding,
45
+ Timesteps,
46
+ )
47
+ from diffusers.models.modeling_utils import ModelMixin
48
+ from marigold.modules.unet_2d_blocks import (
49
+ get_down_block,
50
+ get_mid_block,
51
+ get_up_block,
52
+ BlockFE,
53
+ )
54
+
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+
59
+ @dataclass
60
+ class UNet2DConditionOutput(BaseOutput):
61
+ """
62
+ The output of [`UNet2DConditionModel`].
63
+
64
+ Args:
65
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
66
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
67
+ """
68
+
69
+ sample: torch.Tensor = None
70
+
71
+
72
+ class UNet2DConditionModel(
73
+ ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
74
+ ):
75
+ r"""
76
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
77
+ shaped output.
78
+
79
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
80
+ for all models (such as downloading or saving).
81
+
82
+ Parameters:
83
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
84
+ Height and width of input/output sample.
85
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
86
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
87
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
88
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
89
+ Whether to flip the sin to cos in the time embedding.
90
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
91
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
92
+ The tuple of downsample blocks to use.
93
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
94
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
95
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
96
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
97
+ The tuple of upsample blocks to use.
98
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
99
+ Whether to include self-attention in the basic transformer blocks, see
100
+ [`~models.attention.BasicTransformerBlock`].
101
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
102
+ The tuple of output channels for each block.
103
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
104
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
105
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
106
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
107
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
108
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
109
+ If `None`, normalization and activation layers are skipped in post-processing.
110
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
111
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
112
+ The dimension of the cross attention features.
113
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
114
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
115
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
116
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
117
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
118
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
119
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
120
+ [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`],
121
+ [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
122
+ encoder_hid_dim (`int`, *optional*, defaults to None):
123
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
124
+ dimension to `cross_attention_dim`.
125
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
126
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
127
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
128
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
129
+ num_attention_heads (`int`, *optional*):
130
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
131
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
132
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
133
+ class_embed_type (`str`, *optional*, defaults to `None`):
134
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
135
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
136
+ addition_embed_type (`str`, *optional*, defaults to `None`):
137
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
138
+ "text". "text" will use the `TextTimeEmbedding` layer.
139
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
140
+ Dimension for the timestep embeddings.
141
+ num_class_embeds (`int`, *optional*, defaults to `None`):
142
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
143
+ class conditioning with `class_embed_type` equal to `None`.
144
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
145
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
146
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
147
+ An optional override for the dimension of the projected time embedding.
148
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
149
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
150
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
151
+ timestep_post_act (`str`, *optional*, defaults to `None`):
152
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
153
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
154
+ The dimension of `cond_proj` layer in the timestep embedding.
155
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
156
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
157
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
158
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
159
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
160
+ embeddings with the class embeddings.
161
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
162
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
163
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
164
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
165
+ otherwise.
166
+ """
167
+
168
+ _supports_gradient_checkpointing = True
169
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
170
+
171
+ @register_to_config
172
+ def __init__(
173
+ self,
174
+ sample_size: Optional[int] = None,
175
+ in_channels: int = 4,
176
+ out_channels: int = 4,
177
+ center_input_sample: bool = False,
178
+ flip_sin_to_cos: bool = True,
179
+ freq_shift: int = 0,
180
+ down_block_types: Tuple[str] = (
181
+ "CrossAttnDownBlock2D",
182
+ "CrossAttnDownBlock2D",
183
+ "CrossAttnDownBlock2D",
184
+ "DownBlock2D",
185
+ ),
186
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
187
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
188
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
189
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
190
+ layers_per_block: Union[int, Tuple[int]] = 2,
191
+ downsample_padding: int = 1,
192
+ mid_block_scale_factor: float = 1,
193
+ dropout: float = 0.0,
194
+ act_fn: str = "silu",
195
+ norm_num_groups: Optional[int] = 32,
196
+ norm_eps: float = 1e-5,
197
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
198
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
199
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
200
+ encoder_hid_dim: Optional[int] = None,
201
+ encoder_hid_dim_type: Optional[str] = None,
202
+ attention_head_dim: Union[int, Tuple[int]] = 8,
203
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
204
+ dual_cross_attention: bool = False,
205
+ use_linear_projection: bool = False,
206
+ class_embed_type: Optional[str] = None,
207
+ addition_embed_type: Optional[str] = None,
208
+ addition_time_embed_dim: Optional[int] = None,
209
+ num_class_embeds: Optional[int] = None,
210
+ upcast_attention: bool = False,
211
+ resnet_time_scale_shift: str = "default",
212
+ resnet_skip_time_act: bool = False,
213
+ resnet_out_scale_factor: float = 1.0,
214
+ time_embedding_type: str = "positional",
215
+ time_embedding_dim: Optional[int] = None,
216
+ time_embedding_act_fn: Optional[str] = None,
217
+ timestep_post_act: Optional[str] = None,
218
+ time_cond_proj_dim: Optional[int] = None,
219
+ conv_in_kernel: int = 3,
220
+ conv_out_kernel: int = 3,
221
+ projection_class_embeddings_input_dim: Optional[int] = None,
222
+ attention_type: str = "default",
223
+ class_embeddings_concat: bool = False,
224
+ mid_block_only_cross_attention: Optional[bool] = None,
225
+ cross_attention_norm: Optional[str] = None,
226
+ addition_embed_type_num_heads: int = 64,
227
+ ):
228
+ super().__init__()
229
+ # print('loaded correct file')
230
+
231
+ self.sample_size = sample_size
232
+
233
+ if num_attention_heads is not None:
234
+ raise ValueError(
235
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
236
+ )
237
+
238
+ # If `num_attention_heads` is not defined (which is the case for most models)
239
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
240
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
241
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
242
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
243
+ # which is why we correct for the naming here.
244
+ num_attention_heads = num_attention_heads or attention_head_dim
245
+
246
+ # Check inputs
247
+ self._check_config(
248
+ down_block_types=down_block_types,
249
+ up_block_types=up_block_types,
250
+ only_cross_attention=only_cross_attention,
251
+ block_out_channels=block_out_channels,
252
+ layers_per_block=layers_per_block,
253
+ cross_attention_dim=cross_attention_dim,
254
+ transformer_layers_per_block=transformer_layers_per_block,
255
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
256
+ attention_head_dim=attention_head_dim,
257
+ num_attention_heads=num_attention_heads,
258
+ )
259
+
260
+ # input
261
+ conv_in_padding = (conv_in_kernel - 1) // 2
262
+ self.conv_in = nn.Conv2d(
263
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
264
+ )
265
+
266
+ # time
267
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
268
+ time_embedding_type,
269
+ block_out_channels=block_out_channels,
270
+ flip_sin_to_cos=flip_sin_to_cos,
271
+ freq_shift=freq_shift,
272
+ time_embedding_dim=time_embedding_dim,
273
+ )
274
+
275
+ self.time_embedding = TimestepEmbedding(
276
+ timestep_input_dim,
277
+ time_embed_dim,
278
+ act_fn=act_fn,
279
+ post_act_fn=timestep_post_act,
280
+ cond_proj_dim=time_cond_proj_dim,
281
+ )
282
+
283
+ self._set_encoder_hid_proj(
284
+ encoder_hid_dim_type,
285
+ cross_attention_dim=cross_attention_dim,
286
+ encoder_hid_dim=encoder_hid_dim,
287
+ )
288
+
289
+ # class embedding
290
+ self._set_class_embedding(
291
+ class_embed_type,
292
+ act_fn=act_fn,
293
+ num_class_embeds=num_class_embeds,
294
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
295
+ time_embed_dim=time_embed_dim,
296
+ timestep_input_dim=timestep_input_dim,
297
+ )
298
+
299
+ self._set_add_embedding(
300
+ addition_embed_type,
301
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
302
+ addition_time_embed_dim=addition_time_embed_dim,
303
+ cross_attention_dim=cross_attention_dim,
304
+ encoder_hid_dim=encoder_hid_dim,
305
+ flip_sin_to_cos=flip_sin_to_cos,
306
+ freq_shift=freq_shift,
307
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
308
+ time_embed_dim=time_embed_dim,
309
+ )
310
+
311
+ if time_embedding_act_fn is None:
312
+ self.time_embed_act = None
313
+ else:
314
+ self.time_embed_act = get_activation(time_embedding_act_fn)
315
+
316
+ self.down_blocks = nn.ModuleList([])
317
+ self.up_blocks = nn.ModuleList([])
318
+
319
+ if isinstance(only_cross_attention, bool):
320
+ if mid_block_only_cross_attention is None:
321
+ mid_block_only_cross_attention = only_cross_attention
322
+
323
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
324
+
325
+ if mid_block_only_cross_attention is None:
326
+ mid_block_only_cross_attention = False
327
+
328
+ if isinstance(num_attention_heads, int):
329
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
330
+
331
+ if isinstance(attention_head_dim, int):
332
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
333
+
334
+ if isinstance(cross_attention_dim, int):
335
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
336
+
337
+ if isinstance(layers_per_block, int):
338
+ layers_per_block = [layers_per_block] * len(down_block_types)
339
+
340
+ if isinstance(transformer_layers_per_block, int):
341
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
342
+
343
+ if class_embeddings_concat:
344
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
345
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
346
+ # regular time embeddings
347
+ blocks_time_embed_dim = time_embed_dim * 2
348
+ else:
349
+ blocks_time_embed_dim = time_embed_dim
350
+
351
+ # down
352
+ output_channel = block_out_channels[0]
353
+ for i, down_block_type in enumerate(down_block_types):
354
+ input_channel = output_channel
355
+ output_channel = block_out_channels[i]
356
+ is_final_block = i == len(block_out_channels) - 1
357
+
358
+ down_block = get_down_block(
359
+ down_block_type,
360
+ num_layers=layers_per_block[i],
361
+ transformer_layers_per_block=transformer_layers_per_block[i],
362
+ in_channels=input_channel,
363
+ out_channels=output_channel,
364
+ temb_channels=blocks_time_embed_dim,
365
+ add_downsample=not is_final_block,
366
+ resnet_eps=norm_eps,
367
+ resnet_act_fn=act_fn,
368
+ resnet_groups=norm_num_groups,
369
+ cross_attention_dim=cross_attention_dim[i],
370
+ num_attention_heads=num_attention_heads[i],
371
+ downsample_padding=downsample_padding,
372
+ dual_cross_attention=dual_cross_attention,
373
+ use_linear_projection=use_linear_projection,
374
+ only_cross_attention=only_cross_attention[i],
375
+ upcast_attention=upcast_attention,
376
+ resnet_time_scale_shift=resnet_time_scale_shift,
377
+ attention_type=attention_type,
378
+ resnet_skip_time_act=resnet_skip_time_act,
379
+ resnet_out_scale_factor=resnet_out_scale_factor,
380
+ cross_attention_norm=cross_attention_norm,
381
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
382
+ dropout=dropout,
383
+ )
384
+ self.down_blocks.append(down_block)
385
+
386
+ # mid
387
+ self.mid_block = get_mid_block(
388
+ mid_block_type,
389
+ temb_channels=blocks_time_embed_dim,
390
+ in_channels=block_out_channels[-1],
391
+ resnet_eps=norm_eps,
392
+ resnet_act_fn=act_fn,
393
+ resnet_groups=norm_num_groups,
394
+ output_scale_factor=mid_block_scale_factor,
395
+ transformer_layers_per_block=transformer_layers_per_block[-1],
396
+ num_attention_heads=num_attention_heads[-1],
397
+ cross_attention_dim=cross_attention_dim[-1],
398
+ dual_cross_attention=dual_cross_attention,
399
+ use_linear_projection=use_linear_projection,
400
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
401
+ upcast_attention=upcast_attention,
402
+ resnet_time_scale_shift=resnet_time_scale_shift,
403
+ attention_type=attention_type,
404
+ resnet_skip_time_act=resnet_skip_time_act,
405
+ cross_attention_norm=cross_attention_norm,
406
+ attention_head_dim=attention_head_dim[-1],
407
+ dropout=dropout,
408
+ )
409
+
410
+ # count how many layers upsample the images
411
+ self.num_upsamplers = 0
412
+
413
+ # up
414
+ reversed_block_out_channels = list(reversed(block_out_channels))
415
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
416
+ reversed_layers_per_block = list(reversed(layers_per_block))
417
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
418
+ reversed_transformer_layers_per_block = (
419
+ list(reversed(transformer_layers_per_block))
420
+ if reverse_transformer_layers_per_block is None
421
+ else reverse_transformer_layers_per_block
422
+ )
423
+ only_cross_attention = list(reversed(only_cross_attention))
424
+
425
+ output_channel = reversed_block_out_channels[0]
426
+ for i, up_block_type in enumerate(up_block_types):
427
+ is_final_block = i == len(block_out_channels) - 1
428
+
429
+ prev_output_channel = output_channel
430
+ output_channel = reversed_block_out_channels[i]
431
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
432
+
433
+ # add upsample block for all BUT final layer
434
+ if not is_final_block:
435
+ add_upsample = True
436
+ self.num_upsamplers += 1
437
+ else:
438
+ add_upsample = False
439
+
440
+ up_block = get_up_block(
441
+ up_block_type,
442
+ num_layers=reversed_layers_per_block[i] + 1,
443
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
444
+ in_channels=input_channel,
445
+ out_channels=output_channel,
446
+ prev_output_channel=prev_output_channel,
447
+ temb_channels=blocks_time_embed_dim,
448
+ add_upsample=add_upsample,
449
+ resnet_eps=norm_eps,
450
+ resnet_act_fn=act_fn,
451
+ resolution_idx=i,
452
+ resnet_groups=norm_num_groups,
453
+ cross_attention_dim=reversed_cross_attention_dim[i],
454
+ num_attention_heads=reversed_num_attention_heads[i],
455
+ dual_cross_attention=dual_cross_attention,
456
+ use_linear_projection=use_linear_projection,
457
+ only_cross_attention=only_cross_attention[i],
458
+ upcast_attention=upcast_attention,
459
+ resnet_time_scale_shift=resnet_time_scale_shift,
460
+ attention_type=attention_type,
461
+ resnet_skip_time_act=resnet_skip_time_act,
462
+ resnet_out_scale_factor=resnet_out_scale_factor,
463
+ cross_attention_norm=cross_attention_norm,
464
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
465
+ dropout=dropout,
466
+ )
467
+ self.up_blocks.append(up_block)
468
+ prev_output_channel = output_channel
469
+
470
+ # out
471
+ if norm_num_groups is not None:
472
+ self.conv_norm_out = nn.GroupNorm(
473
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
474
+ )
475
+
476
+ self.conv_act = get_activation(act_fn)
477
+
478
+ else:
479
+ self.conv_norm_out = None
480
+ self.conv_act = None
481
+
482
+ conv_out_padding = (conv_out_kernel - 1) // 2
483
+ self.conv_out = nn.Conv2d(
484
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
485
+ )
486
+
487
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
488
+
489
+ def _check_config(
490
+ self,
491
+ down_block_types: Tuple[str],
492
+ up_block_types: Tuple[str],
493
+ only_cross_attention: Union[bool, Tuple[bool]],
494
+ block_out_channels: Tuple[int],
495
+ layers_per_block: Union[int, Tuple[int]],
496
+ cross_attention_dim: Union[int, Tuple[int]],
497
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
498
+ reverse_transformer_layers_per_block: bool,
499
+ attention_head_dim: int,
500
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
501
+ ):
502
+ if len(down_block_types) != len(up_block_types):
503
+ raise ValueError(
504
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
505
+ )
506
+
507
+ if len(block_out_channels) != len(down_block_types):
508
+ raise ValueError(
509
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
510
+ )
511
+
512
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
513
+ raise ValueError(
514
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
515
+ )
516
+
517
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
518
+ raise ValueError(
519
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
520
+ )
521
+
522
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
523
+ raise ValueError(
524
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
525
+ )
526
+
527
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
528
+ raise ValueError(
529
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
530
+ )
531
+
532
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
533
+ raise ValueError(
534
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
535
+ )
536
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
537
+ for layer_number_per_block in transformer_layers_per_block:
538
+ if isinstance(layer_number_per_block, list):
539
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
540
+
541
+ def _set_time_proj(
542
+ self,
543
+ time_embedding_type: str,
544
+ block_out_channels: int,
545
+ flip_sin_to_cos: bool,
546
+ freq_shift: float,
547
+ time_embedding_dim: int,
548
+ ) -> Tuple[int, int]:
549
+ if time_embedding_type == "fourier":
550
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
551
+ if time_embed_dim % 2 != 0:
552
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
553
+ self.time_proj = GaussianFourierProjection(
554
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
555
+ )
556
+ timestep_input_dim = time_embed_dim
557
+ elif time_embedding_type == "positional":
558
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
559
+
560
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
561
+ timestep_input_dim = block_out_channels[0]
562
+ else:
563
+ raise ValueError(
564
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
565
+ )
566
+
567
+ return time_embed_dim, timestep_input_dim
568
+
569
+ def _set_encoder_hid_proj(
570
+ self,
571
+ encoder_hid_dim_type: Optional[str],
572
+ cross_attention_dim: Union[int, Tuple[int]],
573
+ encoder_hid_dim: Optional[int],
574
+ ):
575
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
576
+ encoder_hid_dim_type = "text_proj"
577
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
578
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
579
+
580
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
581
+ raise ValueError(
582
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
583
+ )
584
+
585
+ if encoder_hid_dim_type == "text_proj":
586
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
587
+ elif encoder_hid_dim_type == "text_image_proj":
588
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
589
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
590
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
591
+ self.encoder_hid_proj = TextImageProjection(
592
+ text_embed_dim=encoder_hid_dim,
593
+ image_embed_dim=cross_attention_dim,
594
+ cross_attention_dim=cross_attention_dim,
595
+ )
596
+ elif encoder_hid_dim_type == "image_proj":
597
+ # Kandinsky 2.2
598
+ self.encoder_hid_proj = ImageProjection(
599
+ image_embed_dim=encoder_hid_dim,
600
+ cross_attention_dim=cross_attention_dim,
601
+ )
602
+ elif encoder_hid_dim_type is not None:
603
+ raise ValueError(
604
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
605
+ )
606
+ else:
607
+ self.encoder_hid_proj = None
608
+
609
+ def _set_class_embedding(
610
+ self,
611
+ class_embed_type: Optional[str],
612
+ act_fn: str,
613
+ num_class_embeds: Optional[int],
614
+ projection_class_embeddings_input_dim: Optional[int],
615
+ time_embed_dim: int,
616
+ timestep_input_dim: int,
617
+ ):
618
+ if class_embed_type is None and num_class_embeds is not None:
619
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
620
+ elif class_embed_type == "timestep":
621
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
622
+ elif class_embed_type == "identity":
623
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
624
+ elif class_embed_type == "projection":
625
+ if projection_class_embeddings_input_dim is None:
626
+ raise ValueError(
627
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
628
+ )
629
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
630
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
631
+ # 2. it projects from an arbitrary input dimension.
632
+ #
633
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
634
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
635
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
636
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
637
+ elif class_embed_type == "simple_projection":
638
+ if projection_class_embeddings_input_dim is None:
639
+ raise ValueError(
640
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
641
+ )
642
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
643
+ else:
644
+ self.class_embedding = None
645
+
646
+ def _set_add_embedding(
647
+ self,
648
+ addition_embed_type: str,
649
+ addition_embed_type_num_heads: int,
650
+ addition_time_embed_dim: Optional[int],
651
+ flip_sin_to_cos: bool,
652
+ freq_shift: float,
653
+ cross_attention_dim: Optional[int],
654
+ encoder_hid_dim: Optional[int],
655
+ projection_class_embeddings_input_dim: Optional[int],
656
+ time_embed_dim: int,
657
+ ):
658
+ if addition_embed_type == "text":
659
+ if encoder_hid_dim is not None:
660
+ text_time_embedding_from_dim = encoder_hid_dim
661
+ else:
662
+ text_time_embedding_from_dim = cross_attention_dim
663
+
664
+ self.add_embedding = TextTimeEmbedding(
665
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
666
+ )
667
+ elif addition_embed_type == "text_image":
668
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
669
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
670
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
671
+ self.add_embedding = TextImageTimeEmbedding(
672
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
673
+ )
674
+ elif addition_embed_type == "text_time":
675
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
676
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
677
+ elif addition_embed_type == "image":
678
+ # Kandinsky 2.2
679
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
680
+ elif addition_embed_type == "image_hint":
681
+ # Kandinsky 2.2 ControlNet
682
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
683
+ elif addition_embed_type is not None:
684
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
685
+
686
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
687
+ if attention_type in ["gated", "gated-text-image"]:
688
+ positive_len = 768
689
+ if isinstance(cross_attention_dim, int):
690
+ positive_len = cross_attention_dim
691
+ elif isinstance(cross_attention_dim, (list, tuple)):
692
+ positive_len = cross_attention_dim[0]
693
+
694
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
695
+ self.position_net = GLIGENTextBoundingboxProjection(
696
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
697
+ )
698
+
699
+ @property
700
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
701
+ r"""
702
+ Returns:
703
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
704
+ indexed by their weight names.
705
+ """
706
+ # set recursively
707
+ processors = {}
708
+
709
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
710
+ if hasattr(module, "get_processor"):
711
+ processors[f"{name}.processor"] = module.get_processor()
712
+
713
+ for sub_name, child in module.named_children():
714
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
715
+
716
+ return processors
717
+
718
+ for name, module in self.named_children():
719
+ fn_recursive_add_processors(name, module, processors)
720
+
721
+ return processors
722
+
723
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
724
+ r"""
725
+ Sets the attention processor to use to compute attention.
726
+
727
+ Parameters:
728
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
729
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
730
+ for **all** `Attention` layers.
731
+
732
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
733
+ processor. This is strongly recommended when setting trainable attention processors.
734
+
735
+ """
736
+ count = len(self.attn_processors.keys())
737
+
738
+ if isinstance(processor, dict) and len(processor) != count:
739
+ raise ValueError(
740
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
741
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
742
+ )
743
+
744
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
745
+ if hasattr(module, "set_processor"):
746
+ if not isinstance(processor, dict):
747
+ module.set_processor(processor)
748
+ else:
749
+ module.set_processor(processor.pop(f"{name}.processor"))
750
+
751
+ for sub_name, child in module.named_children():
752
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
753
+
754
+ for name, module in self.named_children():
755
+ fn_recursive_attn_processor(name, module, processor)
756
+
757
+ def set_default_attn_processor(self):
758
+ """
759
+ Disables custom attention processors and sets the default attention implementation.
760
+ """
761
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
762
+ processor = AttnAddedKVProcessor()
763
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
764
+ processor = AttnProcessor()
765
+ else:
766
+ raise ValueError(
767
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
768
+ )
769
+
770
+ self.set_attn_processor(processor)
771
+
772
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
773
+ r"""
774
+ Enable sliced attention computation.
775
+
776
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
777
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
778
+
779
+ Args:
780
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
781
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
782
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
783
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
784
+ must be a multiple of `slice_size`.
785
+ """
786
+ sliceable_head_dims = []
787
+
788
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
789
+ if hasattr(module, "set_attention_slice"):
790
+ sliceable_head_dims.append(module.sliceable_head_dim)
791
+
792
+ for child in module.children():
793
+ fn_recursive_retrieve_sliceable_dims(child)
794
+
795
+ # retrieve number of attention layers
796
+ for module in self.children():
797
+ fn_recursive_retrieve_sliceable_dims(module)
798
+
799
+ num_sliceable_layers = len(sliceable_head_dims)
800
+
801
+ if slice_size == "auto":
802
+ # half the attention head size is usually a good trade-off between
803
+ # speed and memory
804
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
805
+ elif slice_size == "max":
806
+ # make smallest slice possible
807
+ slice_size = num_sliceable_layers * [1]
808
+
809
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
810
+
811
+ if len(slice_size) != len(sliceable_head_dims):
812
+ raise ValueError(
813
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
814
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
815
+ )
816
+
817
+ for i in range(len(slice_size)):
818
+ size = slice_size[i]
819
+ dim = sliceable_head_dims[i]
820
+ if size is not None and size > dim:
821
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
822
+
823
+ # Recursively walk through all the children.
824
+ # Any children which exposes the set_attention_slice method
825
+ # gets the message
826
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
827
+ if hasattr(module, "set_attention_slice"):
828
+ module.set_attention_slice(slice_size.pop())
829
+
830
+ for child in module.children():
831
+ fn_recursive_set_attention_slice(child, slice_size)
832
+
833
+ reversed_slice_size = list(reversed(slice_size))
834
+ for module in self.children():
835
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
836
+
837
+ def _set_gradient_checkpointing(self, module, value=False):
838
+ if hasattr(module, "gradient_checkpointing"):
839
+ module.gradient_checkpointing = value
840
+
841
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
842
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
843
+
844
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
845
+
846
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
847
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
848
+
849
+ Args:
850
+ s1 (`float`):
851
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
852
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
853
+ s2 (`float`):
854
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
855
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
856
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
857
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
858
+ """
859
+ for i, upsample_block in enumerate(self.up_blocks):
860
+ setattr(upsample_block, "s1", s1)
861
+ setattr(upsample_block, "s2", s2)
862
+ setattr(upsample_block, "b1", b1)
863
+ setattr(upsample_block, "b2", b2)
864
+
865
+ def disable_freeu(self):
866
+ """Disables the FreeU mechanism."""
867
+ freeu_keys = {"s1", "s2", "b1", "b2"}
868
+ for i, upsample_block in enumerate(self.up_blocks):
869
+ for k in freeu_keys:
870
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
871
+ setattr(upsample_block, k, None)
872
+
873
+ def fuse_qkv_projections(self):
874
+ """
875
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
876
+ are fused. For cross-attention modules, key and value projection matrices are fused.
877
+
878
+ <Tip warning={true}>
879
+
880
+ This API is 🧪 experimental.
881
+
882
+ </Tip>
883
+ """
884
+ self.original_attn_processors = None
885
+
886
+ for _, attn_processor in self.attn_processors.items():
887
+ if "Added" in str(attn_processor.__class__.__name__):
888
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
889
+
890
+ self.original_attn_processors = self.attn_processors
891
+
892
+ for module in self.modules():
893
+ if isinstance(module, Attention):
894
+ module.fuse_projections(fuse=True)
895
+
896
+ self.set_attn_processor(FusedAttnProcessor2_0())
897
+
898
+ def unfuse_qkv_projections(self):
899
+ """Disables the fused QKV projection if enabled.
900
+
901
+ <Tip warning={true}>
902
+
903
+ This API is 🧪 experimental.
904
+
905
+ </Tip>
906
+
907
+ """
908
+ if self.original_attn_processors is not None:
909
+ self.set_attn_processor(self.original_attn_processors)
910
+
911
+ def get_time_embed(
912
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
913
+ ) -> Optional[torch.Tensor]:
914
+ timesteps = timestep
915
+ if not torch.is_tensor(timesteps):
916
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
917
+ # This would be a good case for the `match` statement (Python 3.10+)
918
+ is_mps = sample.device.type == "mps"
919
+ if isinstance(timestep, float):
920
+ dtype = torch.float32 if is_mps else torch.float64
921
+ else:
922
+ dtype = torch.int32 if is_mps else torch.int64
923
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
924
+ elif len(timesteps.shape) == 0:
925
+ timesteps = timesteps[None].to(sample.device)
926
+
927
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
928
+ timesteps = timesteps.expand(sample.shape[0])
929
+
930
+ t_emb = self.time_proj(timesteps)
931
+ # `Timesteps` does not contain any weights and will always return f32 tensors
932
+ # but time_embedding might actually be running in fp16. so we need to cast here.
933
+ # there might be better ways to encapsulate this.
934
+ t_emb = t_emb.to(dtype=sample.dtype)
935
+ return t_emb
936
+
937
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
938
+ class_emb = None
939
+ if self.class_embedding is not None:
940
+ if class_labels is None:
941
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
942
+
943
+ if self.config.class_embed_type == "timestep":
944
+ class_labels = self.time_proj(class_labels)
945
+
946
+ # `Timesteps` does not contain any weights and will always return f32 tensors
947
+ # there might be better ways to encapsulate this.
948
+ class_labels = class_labels.to(dtype=sample.dtype)
949
+
950
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
951
+ return class_emb
952
+
953
+ def get_aug_embed(
954
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
955
+ ) -> Optional[torch.Tensor]:
956
+ aug_emb = None
957
+ if self.config.addition_embed_type == "text":
958
+ aug_emb = self.add_embedding(encoder_hidden_states)
959
+ elif self.config.addition_embed_type == "text_image":
960
+ # Kandinsky 2.1 - style
961
+ if "image_embeds" not in added_cond_kwargs:
962
+ raise ValueError(
963
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
964
+ )
965
+
966
+ image_embs = added_cond_kwargs.get("image_embeds")
967
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
968
+ aug_emb = self.add_embedding(text_embs, image_embs)
969
+ elif self.config.addition_embed_type == "text_time":
970
+ # SDXL - style
971
+ if "text_embeds" not in added_cond_kwargs:
972
+ raise ValueError(
973
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
974
+ )
975
+ text_embeds = added_cond_kwargs.get("text_embeds")
976
+ if "time_ids" not in added_cond_kwargs:
977
+ raise ValueError(
978
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
979
+ )
980
+ time_ids = added_cond_kwargs.get("time_ids")
981
+ time_embeds = self.add_time_proj(time_ids.flatten())
982
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
983
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
984
+ add_embeds = add_embeds.to(emb.dtype)
985
+ aug_emb = self.add_embedding(add_embeds)
986
+ elif self.config.addition_embed_type == "image":
987
+ # Kandinsky 2.2 - style
988
+ if "image_embeds" not in added_cond_kwargs:
989
+ raise ValueError(
990
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
991
+ )
992
+ image_embs = added_cond_kwargs.get("image_embeds")
993
+ aug_emb = self.add_embedding(image_embs)
994
+ elif self.config.addition_embed_type == "image_hint":
995
+ # Kandinsky 2.2 - style
996
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
997
+ raise ValueError(
998
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
999
+ )
1000
+ image_embs = added_cond_kwargs.get("image_embeds")
1001
+ hint = added_cond_kwargs.get("hint")
1002
+ aug_emb = self.add_embedding(image_embs, hint)
1003
+ return aug_emb
1004
+
1005
+ def process_encoder_hidden_states(
1006
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1007
+ ) -> torch.Tensor:
1008
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1009
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1010
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1011
+ # Kandinsky 2.1 - style
1012
+ if "image_embeds" not in added_cond_kwargs:
1013
+ raise ValueError(
1014
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1015
+ )
1016
+
1017
+ image_embeds = added_cond_kwargs.get("image_embeds")
1018
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1019
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1020
+ # Kandinsky 2.2 - style
1021
+ if "image_embeds" not in added_cond_kwargs:
1022
+ raise ValueError(
1023
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1024
+ )
1025
+ image_embeds = added_cond_kwargs.get("image_embeds")
1026
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1027
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1028
+ if "image_embeds" not in added_cond_kwargs:
1029
+ raise ValueError(
1030
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1031
+ )
1032
+
1033
+ if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
1034
+ encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1035
+
1036
+ image_embeds = added_cond_kwargs.get("image_embeds")
1037
+ image_embeds = self.encoder_hid_proj(image_embeds)
1038
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1039
+ return encoder_hidden_states
1040
+
1041
+ def forward(
1042
+ self,
1043
+ sample: torch.Tensor,
1044
+ timestep: Union[torch.Tensor, float, int],
1045
+ encoder_hidden_states: torch.Tensor,
1046
+ class_labels: Optional[torch.Tensor] = None,
1047
+ timestep_cond: Optional[torch.Tensor] = None,
1048
+ attention_mask: Optional[torch.Tensor] = None,
1049
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1050
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1051
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1052
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1053
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1054
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1055
+ return_dict: bool = True,
1056
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1057
+ r"""
1058
+ The [`UNet2DConditionModel`] forward method.
1059
+
1060
+ Args:
1061
+ sample (`torch.Tensor`):
1062
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1063
+ timestep (`torch.Tensor` or `float` or `int`): The current denoising timestep.
1064
+ encoder_hidden_states (`torch.Tensor`):
1065
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1066
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1067
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1068
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1069
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1070
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1071
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1072
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1073
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1074
+ negative values to the attention scores corresponding to "discard" tokens.
1075
+ cross_attention_kwargs (`dict`, *optional*):
1076
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1077
+ `self.processor` in
1078
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1079
+ added_cond_kwargs: (`dict`, *optional*):
1080
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1081
+ are passed along to the UNet blocks.
1082
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1083
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1084
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1085
+ A tensor that if specified is added to the residual of the middle unet block.
1086
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1087
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1088
+ encoder_attention_mask (`torch.Tensor`):
1089
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1090
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1091
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1092
+ return_dict (`bool`, *optional*, defaults to `True`):
1093
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1094
+ tuple.
1095
+
1096
+ Returns:
1097
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1098
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1099
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1100
+ """
1101
+ # By default, each spatial dimension of `sample` has to be a multiple of the overall upsampling factor.
1102
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
1103
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1104
+ # on the fly if necessary.
1105
+ default_overall_up_factor = 2**self.num_upsamplers
1106
+
1107
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1108
+ forward_upsample_size = False
1109
+ upsample_size = None
1110
+
1111
+ for dim in sample.shape[-2:]:
1112
+ if dim % default_overall_up_factor != 0:
1113
+ # Forward upsample size to force interpolation output size.
1114
+ forward_upsample_size = True
1115
+ break
1116
+
1117
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1118
+ # expects mask of shape:
1119
+ # [batch, key_tokens]
1120
+ # adds singleton query_tokens dimension:
1121
+ # [batch, 1, key_tokens]
1122
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1123
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1124
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1125
+ if attention_mask is not None:
1126
+ # assume that mask is expressed as:
1127
+ # (1 = keep, 0 = discard)
1128
+ # convert mask into a bias that can be added to attention scores:
1129
+ # (keep = +0, discard = -10000.0)
1130
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1131
+ attention_mask = attention_mask.unsqueeze(1)
1132
+
1133
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1134
+ if encoder_attention_mask is not None:
1135
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1136
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1137
+
1138
+ # 0. center input if necessary
1139
+ if self.config.center_input_sample:
1140
+ sample = 2 * sample - 1.0
1141
+
1142
+ # 1. time
1143
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1144
+ emb = self.time_embedding(t_emb, timestep_cond)
1145
+ aug_emb = None
1146
+
1147
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1148
+ if class_emb is not None:
1149
+ if self.config.class_embeddings_concat:
1150
+ emb = torch.cat([emb, class_emb], dim=-1)
1151
+ else:
1152
+ emb = emb + class_emb
1153
+
1154
+ aug_emb = self.get_aug_embed(
1155
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1156
+ )
1157
+ if self.config.addition_embed_type == "image_hint":
1158
+ aug_emb, hint = aug_emb
1159
+ sample = torch.cat([sample, hint], dim=1)
1160
+
1161
+ emb = emb + aug_emb if aug_emb is not None else emb
1162
+
1163
+ if self.time_embed_act is not None:
1164
+ emb = self.time_embed_act(emb)
1165
+
1166
+ encoder_hidden_states = self.process_encoder_hidden_states(
1167
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1168
+ )
1169
+
1170
+ # 2. pre-process
1171
+ sample = self.conv_in(sample)
1172
+
1173
+ # 2.5 GLIGEN position net
1174
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1175
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1176
+ gligen_args = cross_attention_kwargs.pop("gligen")
1177
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1178
+
1179
+ # 3. down
1180
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1181
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1182
+ if cross_attention_kwargs is not None:
1183
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1184
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1185
+ else:
1186
+ lora_scale = 1.0
1187
+
1188
+ if USE_PEFT_BACKEND:
1189
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1190
+ scale_lora_layers(self, lora_scale)
1191
+
1192
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1193
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1194
+ is_adapter = down_intrablock_additional_residuals is not None
1195
+ # maintain backward compatibility for legacy usage, where
1196
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1197
+ # but can only use one or the other
1198
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1199
+ deprecate(
1200
+ "T2I should not use down_block_additional_residuals",
1201
+ "1.3.0",
1202
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1203
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1204
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1205
+ standard_warn=False,
1206
+ )
1207
+ down_intrablock_additional_residuals = down_block_additional_residuals
1208
+ is_adapter = True
1209
+
1210
+ down_block_res_samples = (sample,)
1211
+ for downsample_block in self.down_blocks:
1212
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1213
+ # For t2i-adapter CrossAttnDownBlock2D
1214
+ additional_residuals = {}
1215
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1216
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1217
+
1218
+ sample, res_samples = downsample_block(
1219
+ hidden_states=sample,
1220
+ temb=emb,
1221
+ encoder_hidden_states=encoder_hidden_states,
1222
+ attention_mask=attention_mask,
1223
+ cross_attention_kwargs=cross_attention_kwargs,
1224
+ encoder_attention_mask=encoder_attention_mask,
1225
+ **additional_residuals,
1226
+ )
1227
+ else:
1228
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1229
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1230
+ sample += down_intrablock_additional_residuals.pop(0)
1231
+
1232
+ down_block_res_samples += res_samples
1233
+
1234
+ if is_controlnet:
1235
+ new_down_block_res_samples = ()
1236
+
1237
+ for down_block_res_sample, down_block_additional_residual in zip(
1238
+ down_block_res_samples, down_block_additional_residuals
1239
+ ):
1240
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1241
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1242
+
1243
+ down_block_res_samples = new_down_block_res_samples
1244
+
1245
+ # 4. mid
1246
+ if self.mid_block is not None:
1247
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1248
+ sample = self.mid_block(
1249
+ sample,
1250
+ emb,
1251
+ encoder_hidden_states=encoder_hidden_states,
1252
+ attention_mask=attention_mask,
1253
+ cross_attention_kwargs=cross_attention_kwargs,
1254
+ encoder_attention_mask=encoder_attention_mask,
1255
+ )
1256
+ else:
1257
+ sample = self.mid_block(sample, emb)
1258
+
1259
+ # To support T2I-Adapter-XL
1260
+ if (
1261
+ is_adapter
1262
+ and len(down_intrablock_additional_residuals) > 0
1263
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1264
+ ):
1265
+ sample += down_intrablock_additional_residuals.pop(0)
1266
+
1267
+ if is_controlnet:
1268
+ sample = sample + mid_block_additional_residual
1269
+
1270
+ # 5. up
1271
+ for i, upsample_block in enumerate(self.up_blocks):
1272
+ is_final_block = i == len(self.up_blocks) - 1
1273
+
1274
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1275
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1276
+
1277
+ # if we have not reached the final block and need to forward the
1278
+ # upsample size, we do it here
1279
+ if not is_final_block and forward_upsample_size:
1280
+ upsample_size = down_block_res_samples[-1].shape[2:]
1281
+
1282
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1283
+ sample = upsample_block(
1284
+ hidden_states=sample,
1285
+ temb=emb,
1286
+ res_hidden_states_tuple=res_samples,
1287
+ encoder_hidden_states=encoder_hidden_states,
1288
+ cross_attention_kwargs=cross_attention_kwargs,
1289
+ upsample_size=upsample_size,
1290
+ attention_mask=attention_mask,
1291
+ encoder_attention_mask=encoder_attention_mask,
1292
+ )
1293
+ else:
1294
+ sample = upsample_block(
1295
+ hidden_states=sample,
1296
+ temb=emb,
1297
+ res_hidden_states_tuple=res_samples,
1298
+ upsample_size=upsample_size,
1299
+ )
1300
+
1301
+ # 6. post-process
1302
+ if self.conv_norm_out:
1303
+ sample = self.conv_norm_out(sample)
1304
+ sample = self.conv_act(sample)
1305
+ sample = self.conv_out(sample)
1306
+
1307
+ if USE_PEFT_BACKEND:
1308
+ # remove `lora_scale` from each PEFT layer
1309
+ unscale_lora_layers(self, lora_scale)
1310
+
1311
+ if not return_dict:
1312
+ return (sample,)
1313
+
1314
+ return UNet2DConditionOutput(sample=sample)
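
The mask handling near the top of `forward` above converts a keep/discard mask into an additive attention bias with a singleton query-tokens dimension. A minimal, standalone sketch of that conversion with illustrative values (not part of the committed file):

import torch

# 1 = keep, 0 = discard, shape (batch, key_tokens)
attention_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

# Same transformation as in forward(): keep -> 0.0, discard -> -10000.0,
# then add a singleton query_tokens dimension so the bias broadcasts over
# attention scores of shape (batch, heads, query_tokens, key_tokens).
bias = (1 - attention_mask) * -10000.0
bias = bias.unsqueeze(1)  # shape (batch, 1, key_tokens)

print(bias.shape, bias)  # roughly: [[[0., 0., -10000., -10000.]]]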
apdepth/util/__pycache__/batchsize.cpython-310.pyc ADDED
Binary file (1.73 kB). View file
 
apdepth/util/__pycache__/batchsize.cpython-312.pyc ADDED
Binary file (2.68 kB). View file
 
apdepth/util/__pycache__/ensemble.cpython-310.pyc ADDED
Binary file (6.51 kB). View file
 
apdepth/util/__pycache__/ensemble.cpython-312.pyc ADDED
Binary file (10.2 kB). View file
 
apdepth/util/__pycache__/image_util.cpython-310.pyc ADDED
Binary file (2.81 kB). View file
 
apdepth/util/__pycache__/image_util.cpython-312.pyc ADDED
Binary file (4.87 kB). View file
 
apdepth/util/batchsize.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ import torch
22
+ import math
23
+
24
+
25
+ # Search table for suggested max. inference batch size
26
+ bs_search_table = [
27
+ # tested on A100-PCIE-80GB
28
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
29
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
30
+ # tested on A100-PCIE-40GB
31
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
32
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
33
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
34
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
35
+ # tested on RTX3090, RTX4090
36
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
37
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
38
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
39
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
40
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
41
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
42
+ # tested on GTX1080Ti
43
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
44
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
45
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
46
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
47
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
48
+ ]
49
+
50
+
51
+ def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
52
+ """
53
+ Automatically search for a suitable inference batch size.
54
+
55
+ Args:
56
+ ensemble_size (`int`):
57
+ Number of predictions to be ensembled.
58
+ input_res (`int`):
59
+ Operating resolution of the input image.
60
+
61
+ Returns:
62
+ `int`: Operating batch size.
63
+ """
64
+ if not torch.cuda.is_available():
65
+ return 1
66
+
67
+ total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
68
+ filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
69
+ for settings in sorted(
70
+ filtered_bs_search_table,
71
+ key=lambda k: (k["res"], -k["total_vram"]),
72
+ ):
73
+ if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
74
+ bs = settings["bs"]
75
+ if bs > ensemble_size:
76
+ bs = ensemble_size
77
+ elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
78
+ bs = math.ceil(ensemble_size / 2)
79
+ return bs
80
+
81
+ return 1
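
A small usage sketch for `find_batch_size`, assuming the repository root is on the Python path so `apdepth.util.batchsize` is importable; the GPU figures in the comment are only what the search table above would suggest on one example card:

import torch

from apdepth.util.batchsize import find_batch_size

# Pick a batch size for ensembling 10 predictions at 768 px in half precision.
# Without CUDA this falls back to 1; on a ~24 GB GPU the table suggests 18,
# which the function then caps at the ensemble size (10).
bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(bs)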
apdepth/util/ensemble.py ADDED
@@ -0,0 +1,200 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------------------------
15
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
16
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
17
+ # More information about the method can be found at https://marigoldmonodepth.github.io
18
+ # --------------------------------------------------------------------------
19
+
20
+
21
+ from functools import partial
22
+ from typing import Optional, Tuple
23
+
24
+ import numpy as np
25
+ import torch
26
+
27
+ from .image_util import get_tv_resample_method, resize_max_res
28
+
29
+
30
+ def inter_distances(tensors: torch.Tensor):
31
+ """
32
+ To calculate the distance between each two depth maps.
33
+ """
34
+ distances = []
35
+ for i, j in torch.combinations(torch.arange(tensors.shape[0])):
36
+ arr1 = tensors[i : i + 1]
37
+ arr2 = tensors[j : j + 1]
38
+ distances.append(arr1 - arr2)
39
+ dist = torch.concatenate(distances, dim=0)
40
+ return dist
41
+
42
+
43
+ def ensemble_depth(
44
+ depth: torch.Tensor,
45
+ scale_invariant: bool = True,
46
+ shift_invariant: bool = True,
47
+ output_uncertainty: bool = False,
48
+ reduction: str = "median",
49
+ regularizer_strength: float = 0.02,
50
+ max_iter: int = 2,
51
+ tol: float = 1e-3,
52
+ max_res: int = 1024,
53
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
54
+ """
55
+ Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
56
+ number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
57
+ depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
58
+ alignment happens when the predictions have one or more degrees of freedom, that is when they are either
59
+ affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
60
+ `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
61
+ alignment is skipped and only ensembling is performed.
62
+
63
+ Args:
64
+ depth (`torch.Tensor`):
65
+ Input ensemble depth maps.
66
+ scale_invariant (`bool`, *optional*, defaults to `True`):
67
+ Whether to treat predictions as scale-invariant.
68
+ shift_invariant (`bool`, *optional*, defaults to `True`):
69
+ Whether to treat predictions as shift-invariant.
70
+ output_uncertainty (`bool`, *optional*, defaults to `False`):
71
+ Whether to output uncertainty map.
72
+ reduction (`str`, *optional*, defaults to `"median"`):
73
+ Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
74
+ `"median"`.
75
+ regularizer_strength (`float`, *optional*, defaults to `0.02`):
76
+ Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
77
+ max_iter (`int`, *optional*, defaults to `2`):
78
+ Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
79
+ argument.
80
+ tol (`float`, *optional*, defaults to `1e-3`):
81
+ Alignment solver tolerance. The solver stops when the tolerance is reached.
82
+ max_res (`int`, *optional*, defaults to `1024`):
83
+ Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
84
+ Returns:
85
+ A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
86
+ `(1, 1, H, W)`.
87
+ """
88
+ if depth.dim() != 4 or depth.shape[1] != 1:
89
+ raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
90
+ if reduction not in ("mean", "median"):
91
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
92
+ if not scale_invariant and shift_invariant:
93
+ raise ValueError("Pure shift-invariant ensembling is not supported.")
94
+
95
+ def init_param(depth: torch.Tensor):
96
+ init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
97
+ init_max = depth.reshape(ensemble_size, -1).max(dim=1).values
98
+
99
+ if scale_invariant and shift_invariant:
100
+ init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
101
+ init_t = -init_s * init_min
102
+ param = torch.cat((init_s, init_t)).cpu().numpy()
103
+ elif scale_invariant:
104
+ init_s = 1.0 / init_max.clamp(min=1e-6)
105
+ param = init_s.cpu().numpy()
106
+ else:
107
+ raise ValueError("Unrecognized alignment.")
108
+
109
+ return param
110
+
111
+ def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
112
+ if scale_invariant and shift_invariant:
113
+ s, t = np.split(param, 2)
114
+ s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
115
+ t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
116
+ out = depth * s + t
117
+ elif scale_invariant:
118
+ s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
119
+ out = depth * s
120
+ else:
121
+ raise ValueError("Unrecognized alignment.")
122
+ return out
123
+
124
+ def ensemble(
125
+ depth_aligned: torch.Tensor, return_uncertainty: bool = False
126
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
127
+ uncertainty = None
128
+ if reduction == "mean":
129
+ prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
130
+ if return_uncertainty:
131
+ uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
132
+ elif reduction == "median":
133
+ prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
134
+ if return_uncertainty:
135
+ uncertainty = torch.median(
136
+ torch.abs(depth_aligned - prediction), dim=0, keepdim=True
137
+ ).values
138
+ else:
139
+ raise ValueError(f"Unrecognized reduction method: {reduction}.")
140
+ return prediction, uncertainty
141
+
142
+ def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
143
+ cost = 0.0
144
+ depth_aligned = align(depth, param)
145
+
146
+ for i, j in torch.combinations(torch.arange(ensemble_size)):
147
+ diff = depth_aligned[i] - depth_aligned[j]
148
+ cost += (diff**2).mean().sqrt().item()
149
+
150
+ if regularizer_strength > 0:
151
+ prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
152
+ err_near = (0.0 - prediction.min()).abs().item()
153
+ err_far = (1.0 - prediction.max()).abs().item()
154
+ cost += (err_near + err_far) * regularizer_strength
155
+
156
+ return cost
157
+
158
+ def compute_param(depth: torch.Tensor):
159
+ import scipy
160
+
161
+ depth_to_align = depth.to(torch.float32)
162
+ if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
163
+ depth_to_align = resize_max_res(
164
+ depth_to_align, max_res, get_tv_resample_method("nearest-exact")
165
+ )
166
+
167
+ param = init_param(depth_to_align)
168
+
169
+ res = scipy.optimize.minimize(
170
+ partial(cost_fn, depth=depth_to_align),
171
+ param,
172
+ method="BFGS",
173
+ tol=tol,
174
+ options={"maxiter": max_iter, "disp": False},
175
+ )
176
+
177
+ return res.x
178
+
179
+ requires_aligning = scale_invariant or shift_invariant
180
+ ensemble_size = depth.shape[0]
181
+
182
+ if requires_aligning:
183
+ param = compute_param(depth)
184
+ depth = align(depth, param)
185
+
186
+ depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)
187
+
188
+ depth_max = depth.max()
189
+ if scale_invariant and shift_invariant:
190
+ depth_min = depth.min()
191
+ elif scale_invariant:
192
+ depth_min = 0
193
+ else:
194
+ raise ValueError("Unrecognized alignment.")
195
+ depth_range = (depth_max - depth_min).clamp(min=1e-6)
196
+ depth = (depth - depth_min) / depth_range
197
+ if output_uncertainty:
198
+ uncertainty /= depth_range
199
+
200
+ return depth, uncertainty # [1,1,H,W], [1,1,H,W]
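
A short, hedged example of calling `ensemble_depth` on random stand-in predictions (the alignment solver needs `scipy`, and the import path assumes the package layout above is on the Python path):

import torch

from apdepth.util.ensemble import ensemble_depth

# Five affine-invariant stand-in predictions for a 48x64 image, shape (B, 1, H, W).
preds = torch.rand(5, 1, 48, 64)

depth, uncertainty = ensemble_depth(
    preds,
    scale_invariant=True,
    shift_invariant=True,
    output_uncertainty=True,
    reduction="median",
)
print(depth.shape, uncertainty.shape)          # both (1, 1, 48, 64)
print(depth.min().item(), depth.max().item())  # normalized to the [0, 1] range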
apdepth/util/image_util.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
2
+ # Last modified: 2024-05-24
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # --------------------------------------------------------------------------
16
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
17
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
18
+ # More information about the method can be found at https://marigoldmonodepth.github.io
19
+ # --------------------------------------------------------------------------
20
+
21
+
22
+ import matplotlib
23
+ import numpy as np
24
+ import torch
25
+ from torchvision.transforms import InterpolationMode
26
+ from torchvision.transforms.functional import resize
27
+
28
+
29
+ def colorize_depth_maps(
30
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
31
+ ):
32
+ """
33
+ Colorize depth maps.
34
+ """
35
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
36
+
37
+ if isinstance(depth_map, torch.Tensor):
38
+ depth = depth_map.detach().squeeze().numpy()
39
+ elif isinstance(depth_map, np.ndarray):
40
+ depth = depth_map.copy().squeeze()
41
+ # reshape to [ (B,) H, W ]
42
+ if depth.ndim < 3:
43
+ depth = depth[np.newaxis, :, :]
44
+
45
+ # colorize
46
+ cm = matplotlib.colormaps[cmap]
47
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
48
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
49
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
50
+
51
+ if valid_mask is not None:
52
+ if isinstance(depth_map, torch.Tensor):
53
+ valid_mask = valid_mask.detach().numpy()
54
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
55
+ if valid_mask.ndim < 3:
56
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
57
+ else:
58
+ valid_mask = valid_mask[:, np.newaxis, :, :]
59
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
60
+ img_colored_np[~valid_mask] = 0
61
+
62
+ if isinstance(depth_map, torch.Tensor):
63
+ img_colored = torch.from_numpy(img_colored_np).float()
64
+ elif isinstance(depth_map, np.ndarray):
65
+ img_colored = img_colored_np
66
+
67
+ return img_colored
68
+
69
+
70
+ def chw2hwc(chw):
71
+ assert 3 == len(chw.shape)
72
+ if isinstance(chw, torch.Tensor):
73
+ hwc = torch.permute(chw, (1, 2, 0))
74
+ elif isinstance(chw, np.ndarray):
75
+ hwc = np.moveaxis(chw, 0, -1)
76
+ return hwc
77
+
78
+
79
+ def resize_max_res(
80
+ img: torch.Tensor,
81
+ max_edge_resolution: int,
82
+ resample_method: InterpolationMode = InterpolationMode.BILINEAR,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Resize image to limit maximum edge length while keeping aspect ratio.
86
+
87
+ Args:
88
+ img (`torch.Tensor`):
89
+ Image tensor to be resized. Expected shape: [B, C, H, W]
90
+ max_edge_resolution (`int`):
91
+ Maximum edge length (pixel).
92
+ resample_method (`PIL.Image.Resampling`):
93
+ Resampling method used to resize images.
94
+
95
+ Returns:
96
+ `torch.Tensor`: Resized image.
97
+ """
98
+ assert 4 == img.dim(), f"Invalid input shape {img.shape}"
99
+
100
+ original_height, original_width = img.shape[-2:]
101
+ downscale_factor = min(
102
+ max_edge_resolution / original_width, max_edge_resolution / original_height
103
+ )
104
+
105
+ new_width = int(original_width * downscale_factor)
106
+ new_height = int(original_height * downscale_factor)
107
+
108
+ resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
109
+ return resized_img
110
+
111
+
112
+ def get_tv_resample_method(method_str: str) -> InterpolationMode:
113
+ resample_method_dict = {
114
+ "bilinear": InterpolationMode.BILINEAR,
115
+ "bicubic": InterpolationMode.BICUBIC,
116
+ "nearest": InterpolationMode.NEAREST_EXACT,
117
+ "nearest-exact": InterpolationMode.NEAREST_EXACT,
118
+ }
119
+ resample_method = resample_method_dict.get(method_str, None)
120
+ if resample_method is None:
121
+ raise ValueError(f"Unknown resampling method: {resample_method}")
122
+ else:
123
+ return resample_method
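
Finally, a brief sketch of the image helpers in this file on dummy data (requires `torch`, `torchvision`, and `matplotlib`; shapes and values are illustrative, and the import path assumes the repository root is on the Python path):

import torch

from apdepth.util.image_util import (
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# Downscale a dummy [B, C, H, W] image so its longer edge is at most 512 px.
img = torch.rand(1, 3, 480, 640)
img_small = resize_max_res(img, 512, get_tv_resample_method("bilinear"))
print(img_small.shape)  # torch.Size([1, 3, 384, 512])

# Colorize a dummy depth map in [0, 1] with the Spectral colormap.
depth = torch.rand(480, 640)
colored = colorize_depth_maps(depth, min_depth=0.0, max_depth=1.0, cmap="Spectral")
print(colored.shape)  # torch.Size([1, 3, 480, 640])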