GonzaloMG committed
Commit 63f09df
1 Parent(s): aafb6de

Upload pipeline.py

Files changed (1)
  1. pipeline.py +478 -0
pipeline.py ADDED
@@ -0,0 +1,478 @@
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # More information and citation instructions are available on the
+ # Marigold project website: https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+
+ # @GonzaloMartinGarcia
+ # Inference Pipeline for End-to-End Marigold and Stable Diffusion Depth Estimators
+ # ----------------------------------------------------------------------------------
+ # A streamlined version of the official MarigoldDepthPipeline from diffusers:
+ # https://github.com/huggingface/diffusers/blob/a98a839de75f1ad82d8d200c3bc2e4ff89929081/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py#L96
+ #
+ # This implementation is meant for use with the diffusers custom_pipeline feature.
+ # Modifications from the original code are marked with '# add' comments.
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.models import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+ )
+ from diffusers.schedulers import (
+     DDIMScheduler,
+ )
+ from diffusers.utils import (
+     BaseOutput,
+     logging,
+ )
+ from diffusers import DiffusionPipeline
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
+
+ # add
+ def zeros_tensor(
+     shape: Union[Tuple, List],
+     device: Optional["torch.device"] = None,
+     dtype: Optional["torch.dtype"] = None,
+     layout: Optional["torch.layout"] = None,
+ ):
+     """
+     A helper function to create tensors of zeros on the desired `device`.
+     Mirrors randn_tensor from diffusers.utils.torch_utils.
+     """
+     layout = layout or torch.strided
+     device = device or torch.device("cpu")
+     latents = torch.zeros(list(shape), dtype=dtype, layout=layout).to(device)
+     return latents
+
+
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ @dataclass
+ class E2EMarigoldDepthOutput(BaseOutput):
+     """
+     Output class for Marigold monocular depth prediction pipeline.
+
+     Args:
+         prediction (`np.ndarray`, `torch.Tensor`):
+             Predicted depth maps with values in the range [0, 1]. The shape is always $numimages \times 1 \times height
+             \times width$, regardless of whether the images were passed as a 4D array or a list.
+         latent (`None`, `torch.Tensor`):
+             Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
+             The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+     """
+
+     prediction: Union[np.ndarray, torch.Tensor]
+     latent: Union[None, torch.Tensor]
+
+
+ class E2EMarigoldDepthPipeline(DiffusionPipeline):
+     """
+     # add
+     Pipeline for monocular depth estimation using the E2E FT Marigold and SD method: https://gonzalomartingarcia.github.io/diffusion-e2e-ft/
+     Implementation is built upon Marigold: https://marigoldmonodepth.github.io
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         unet (`UNet2DConditionModel`):
+             Conditional U-Net to denoise the depth latent, conditioned on image latent.
+         vae (`AutoencoderKL`):
+             Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
+             representations.
+         scheduler (`DDIMScheduler`):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+         text_encoder (`CLIPTextModel`):
+             Text-encoder, for empty text embedding.
+         tokenizer (`CLIPTokenizer`):
+             CLIP tokenizer.
+         prediction_type (`str`, *optional*):
+             Type of predictions made by the model.
+         default_processing_resolution (`int`, *optional*):
+             The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
+             the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
+             default value is used. This is required to ensure reasonable results with various model flavors trained
+             with varying optimal processing resolution values.
+     """
+
+     model_cpu_offload_seq = "text_encoder->unet->vae"
+     supported_prediction_types = ("depth", "disparity")
+
+     def __init__(
+         self,
+         unet: UNet2DConditionModel,
+         vae: AutoencoderKL,
+         scheduler: Union[DDIMScheduler],
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         prediction_type: Optional[str] = None,
+         default_processing_resolution: Optional[int] = None,
+     ):
+         super().__init__()
+
+         if prediction_type not in self.supported_prediction_types:
+             logger.warning(
+                 f"Potentially unsupported `prediction_type='{prediction_type}'`; values supported by the pipeline: "
+                 f"{self.supported_prediction_types}."
+             )
+
+         self.register_modules(
+             unet=unet,
+             vae=vae,
+             scheduler=scheduler,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+         )
+         self.register_to_config(
+             prediction_type=prediction_type,
+             default_processing_resolution=default_processing_resolution,
+         )
+
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+         self.default_processing_resolution = default_processing_resolution
+         self.empty_text_embedding = None
+
+         self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+     def check_inputs(
+         self,
+         image: PipelineImageInput,
+         processing_resolution: int,
+         resample_method_input: str,
+         resample_method_output: str,
+         batch_size: int,
+         output_type: str,
+     ) -> int:
+         if processing_resolution is None:
+             raise ValueError(
+                 "`processing_resolution` is not specified and could not be resolved from the model config."
+             )
+         if processing_resolution < 0:
+             raise ValueError(
+                 "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
+                 "downsampled processing."
+             )
+         if processing_resolution % self.vae_scale_factor != 0:
+             raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
+         if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+             raise ValueError(
+                 "`resample_method_input` takes string values compatible with PIL library: "
+                 "nearest, nearest-exact, bilinear, bicubic, area."
+             )
+         if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+             raise ValueError(
+                 "`resample_method_output` takes string values compatible with PIL library: "
+                 "nearest, nearest-exact, bilinear, bicubic, area."
+             )
+         if batch_size < 1:
+             raise ValueError("`batch_size` must be positive.")
+         if output_type not in ["pt", "np"]:
+             raise ValueError("`output_type` must be one of `pt` or `np`.")
+
+         # image checks
+         num_images = 0
+         W, H = None, None
+         if not isinstance(image, list):
+             image = [image]
+         for i, img in enumerate(image):
+             if isinstance(img, np.ndarray) or torch.is_tensor(img):
+                 if img.ndim not in (2, 3, 4):
+                     raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
+                 H_i, W_i = img.shape[-2:]
+                 N_i = 1
+                 if img.ndim == 4:
+                     N_i = img.shape[0]
+             elif isinstance(img, Image.Image):
+                 W_i, H_i = img.size
+                 N_i = 1
+             else:
+                 raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
+             if W is None:
+                 W, H = W_i, H_i
+             elif (W, H) != (W_i, H_i):
+                 raise ValueError(
+                     f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
+                 )
+             num_images += N_i
+
+         return num_images
+
+     def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
+         if not hasattr(self, "_progress_bar_config"):
+             self._progress_bar_config = {}
+         elif not isinstance(self._progress_bar_config, dict):
+             raise ValueError(
+                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+             )
+
+         progress_bar_config = dict(**self._progress_bar_config)
+         progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
+         progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
+         if iterable is not None:
+             return tqdm(iterable, **progress_bar_config)
+         elif total is not None:
+             return tqdm(total=total, **progress_bar_config)
+         else:
+             raise ValueError("Either `total` or `iterable` has to be defined.")
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: PipelineImageInput,
+         processing_resolution: Optional[int] = None,
+         match_input_resolution: bool = True,
+         resample_method_input: str = "bilinear",
+         resample_method_output: str = "bilinear",
+         batch_size: int = 1,
+         output_type: str = "np",
+         output_latent: bool = False,
+         return_dict: bool = True,
+     ):
+         """
+         Function invoked when calling the pipeline.
+
+         Args:
+             image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
+                 `List[torch.Tensor]`: An input image or images used as an input for the depth estimation task. For
+                 arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
+                 by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
+                 three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
+                 same width and height.
+             processing_resolution (`int`, *optional*, defaults to `None`):
+                 Effective processing resolution. When set to `0`, matches the larger input image dimension. This
+                 produces crisper predictions, but may also lead to the overall loss of global context. The default
+                 value `None` resolves to the optimal value from the model config.
+             match_input_resolution (`bool`, *optional*, defaults to `True`):
+                 When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
+                 side of the output will equal `processing_resolution`.
+             resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
+                 Resampling method used to resize input images to `processing_resolution`. The accepted values are:
+                 `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+             resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
+                 Resampling method used to resize output predictions to match the input resolution. The accepted values
+                 are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+             batch_size (`int`, *optional*, defaults to `1`):
+                 Batch size; only matters when passing a tensor of images.
+             output_type (`str`, *optional*, defaults to `"np"`):
+                 Preferred format of the output's `prediction` field. The accepted values are: `"np"` (numpy array) or
+                 `"pt"` (torch tensor).
+             output_latent (`bool`, *optional*, defaults to `False`):
+                 When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
+                 within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
+                 `latents` argument.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.marigold.E2EMarigoldDepthOutput`] instead of a plain tuple.
+
+         # add
+         E2E FT models are deterministic, single-step models with no ensembling, i.e. E=1.
+         """
+
+         # 0. Resolving variables.
+         device = self._execution_device
+         dtype = self.dtype
+
+         # Model-specific optimal default values leading to fast and reasonable results.
+         if processing_resolution is None:
+             processing_resolution = self.default_processing_resolution
+
+         # 1. Check inputs.
+         num_images = self.check_inputs(
+             image,
+             processing_resolution,
+             resample_method_input,
+             resample_method_output,
+             batch_size,
+             output_type,
+         )
+
+         # 2. Prepare empty text conditioning.
+         # Model invocation: self.tokenizer, self.text_encoder.
+         if self.empty_text_embedding is None:
+             prompt = ""
+             text_inputs = self.tokenizer(
+                 prompt,
+                 padding="do_not_pad",
+                 max_length=self.tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             text_input_ids = text_inputs.input_ids.to(device)
+             self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
+
+         # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
+         # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
+         # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
+         # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
+         # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
+         # operation and leads to the most reasonable results. Using the native image resolution or any other processing
+         # resolution can lead to loss of either fine details or global context in the output predictions.
+         image, padding, original_resolution = self.image_processor.preprocess(
+             image, processing_resolution, resample_method_input, device, dtype
+         ) # [N,3,PPH,PPW]
+
+         # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
+         # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
+         # Latents of each such prediction across all input images and all ensemble members are represented in the
+         # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
+         # into latent space and replicated `E` times. Encoding into latent space happens in batches of size `batch_size`.
+         # Model invocation: self.vae.encoder.
+         image_latent, pred_latent = self.prepare_latents(
+             image, batch_size
+         ) # [N*E,4,h,w], [N*E,4,h,w]
+
+         del image
+
+         batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
+             batch_size, 1, 1
+         ) # [B,2,1024]
+
+         # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
+         # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
+         # outputs noise for the predicted modality's latent space. The E2E FT models are deterministic, so a single
+         # denoising step is used; the scheduler is set to one timestep below.
+         # Model invocation: self.unet.
+         pred_latents = []
+
+         for i in self.progress_bar(
+             range(0, num_images, batch_size), leave=True, desc="E2E FT predictions..."
+         ):
+             batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w]
+             batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w]
+             effective_batch_size = batch_image_latent.shape[0]
+             text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024]
+
+             # add
+             # Single step inference for E2E FT models
+             self.scheduler.set_timesteps(1, device=device)
+             for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."):
+                 batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w]
+                 noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w]
+                 batch_pred_latent = self.scheduler.step(
+                     noise, t, batch_pred_latent
+                 ).pred_original_sample # [B,4,h,w], # add
+                 # directly take pred_original_sample rather than prev_sample
+
+             pred_latents.append(batch_pred_latent)
+
+         pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
+
+         del (
+             pred_latents,
+             image_latent,
+             batch_empty_text_embedding,
+             batch_image_latent,
+             batch_pred_latent,
+             text,
+             batch_latent,
+             noise,
+         )
+
+         # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
+         # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
+         # Model invocation: self.vae.decoder.
+         prediction = torch.cat(
+             [
+                 self.decode_prediction(pred_latent[i : i + batch_size])
+                 for i in range(0, pred_latent.shape[0], batch_size)
+             ],
+             dim=0,
+         ) # [N*E,1,PPH,PPW]
+
+         if not output_latent:
+             pred_latent = None
+
+         # 7. Remove padding. The output shape is (PH, PW).
+         prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW]
+
+         # 9. If `match_input_resolution` is set, the output prediction is upsampled to match the input
+         # resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
+         # Depending on the downstream use-case, upsampling can also be chosen based on the tolerated artifacts by
+         # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
+         if match_input_resolution:
+             prediction = self.image_processor.resize_antialias(
+                 prediction, original_resolution, resample_method_output, is_aa=False
+             ) # [N,1,H,W]
+
+         # 10. Prepare the final outputs.
+         if output_type == "np":
+             prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1]
+
+         # 11. Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (prediction, pred_latent)
+
+         return E2EMarigoldDepthOutput(
+             prediction=prediction,
+             latent=pred_latent,
+         )
+
+     def prepare_latents(
+         self,
+         image: torch.Tensor,
+         batch_size: int,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         def retrieve_latents(encoder_output):
+             if hasattr(encoder_output, "latent_dist"):
+                 return encoder_output.latent_dist.mode()
+             elif hasattr(encoder_output, "latents"):
+                 return encoder_output.latents
+             else:
+                 raise AttributeError("Could not access latents of provided encoder_output")
+
+         image_latent = torch.cat(
+             [
+                 retrieve_latents(self.vae.encode(image[i : i + batch_size]))
+                 for i in range(0, image.shape[0], batch_size)
+             ],
+             dim=0,
+         ) # [N,4,h,w]
+         image_latent = image_latent * self.vae.config.scaling_factor # [N*E,4,h,w]
+
+         # add
+         # provide zeros as noised latent
+         pred_latent = zeros_tensor(
+             image_latent.shape,
+             device=image_latent.device,
+             dtype=image_latent.dtype,
+         ) # [N*E,4,h,w]
+
+         return image_latent, pred_latent
+
+     def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
+         if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
+             raise ValueError(
+                 f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
+             )
+
+         prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
+
+         prediction = prediction.mean(dim=1, keepdim=True) # [B,1,H,W]
+         prediction = torch.clip(prediction, -1.0, 1.0) # [B,1,H,W]
+         prediction = (prediction + 1.0) / 2.0
+
+         return prediction # [B,1,H,W]
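
For reference, a minimal usage sketch of loading this file through the diffusers custom_pipeline feature, which the header comment describes as the intended entry point. The repository id, image URL, and device below are placeholders/assumptions and are not part of this commit:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Placeholder repo id: substitute the Hub repository that hosts this pipeline.py
# together with the E2E fine-tuned depth checkpoint.
repo_id = "<org>/<e2e-ft-depth-checkpoint>"
pipe = DiffusionPipeline.from_pretrained(
    repo_id,
    custom_pipeline=repo_id,  # resolves to the pipeline.py in the same repository
    torch_dtype=torch.float32,
).to("cuda")  # assumes a CUDA device is available

image = load_image("https://example.com/example.jpg")  # placeholder input image
output = pipe(image, output_type="np")
depth = output.prediction  # numpy array of shape [N,H,W,1], values in [0,1]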