GonzaloMG committed on
Commit 3603084
1 Parent(s): 9d6f3ad

Upload pipeline.py

Files changed (1)
  1. pipeline.py +489 -0
pipeline.py ADDED
@@ -0,0 +1,489 @@
+ # Copyright 2024 Marigold authors, PRS ETH Zurich. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+ # More information and citation instructions are available on the
+ # Marigold project website: https://marigoldmonodepth.github.io
+ # --------------------------------------------------------------------------
+
+ # @GonzaloMartinGarcia
+ # Inference Pipeline for End-to-End Marigold and Stable Diffusion Surface Normal Estimators
+ # ----------------------------------------------------------------------------------
+ # A streamlined version of the official MarigoldDepthPipeline from diffusers:
+ # https://github.com/huggingface/diffusers/blob/a98a839de75f1ad82d8d200c3bc2e4ff89929081/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py#L96
+ #
+ # This implementation is meant for use with the diffusers custom_pipeline feature.
+ # Modifications from the original code are marked with '# add' comments.
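+ #
+ # A minimal usage sketch (illustrative only; the repository id below is a placeholder for a
+ # checkpoint that ships this pipeline.py as its custom pipeline):
+ #
+ #   import diffusers
+ #   from PIL import Image
+ #
+ #   pipe = diffusers.DiffusionPipeline.from_pretrained(
+ #       "some-user/e2e-ft-normals",                  # placeholder repository id
+ #       custom_pipeline="some-user/e2e-ft-normals",  # placeholder repository id
+ #   ).to("cuda")
+ #   out = pipe(Image.open("input.jpg"))
+ #   normals = out.prediction  # [N,H,W,3] in [-1,1] for the default output_type="np"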
+
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm.auto import tqdm
+ from transformers import CLIPTextModel, CLIPTokenizer
+
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.models import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+ )
+ from diffusers.schedulers import (
+     DDIMScheduler,
+ )
+ from diffusers.utils import (
+     BaseOutput,
+     logging,
+ )
+ from diffusers import DiffusionPipeline
+ from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
+
+ # add
+ def zeros_tensor(
+     shape: Union[Tuple, List],
+     device: Optional["torch.device"] = None,
+     dtype: Optional["torch.dtype"] = None,
+     layout: Optional["torch.layout"] = None,
+ ):
+     """
+     A helper function to create tensors of zeros on the desired `device`.
+     Mirrors randn_tensor from diffusers.utils.torch_utils.
+     """
+     layout = layout or torch.strided
+     device = device or torch.device("cpu")
+     latents = torch.zeros(list(shape), dtype=dtype, layout=layout).to(device)
+     return latents
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class E2EMarigoldNormalsOutput(BaseOutput):
+     """
+     Output class for Marigold monocular normals prediction pipeline.
+
+     Args:
+         prediction (`np.ndarray`, `torch.Tensor`):
+             Predicted normals with values in the range [-1, 1]. The shape is always $numimages \times 3 \times height
+             \times width$, regardless of whether the images were passed as a 4D array or a list.
+         latent (`None`, `torch.Tensor`):
+             Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
+             The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+     """
+
+     prediction: Union[np.ndarray, torch.Tensor]
+     latent: Union[None, torch.Tensor]
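+
+     # A minimal visualization sketch (illustrative only; assumes the default `output_type="np"`,
+     # where `prediction` is an [N,H,W,3] array of unit normals in [-1,1], and `out`/`Image` as in
+     # the header example):
+     #
+     #   vis = ((out.prediction[0] + 1.0) * 127.5).clip(0, 255).astype("uint8")  # [H,W,3] RGB
+     #   Image.fromarray(vis).save("normals.png")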
+
+
+ class E2EMarigoldNormalsPipeline(DiffusionPipeline):
+     """
+     # add
+     Pipeline for monocular normals estimation using the E2E FT Marigold and SD method: https://gonzalomartingarcia.github.io/diffusion-e2e-ft/
+     Implementation is built upon Marigold: https://marigoldmonodepth.github.io
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+     Args:
+         unet (`UNet2DConditionModel`):
+             Conditional U-Net to denoise the normals latent, conditioned on image latent.
+         vae (`AutoencoderKL`):
+             Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
+             representations.
+         scheduler (`DDIMScheduler` or `LCMScheduler`):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents.
+         text_encoder (`CLIPTextModel`):
+             Text-encoder, for empty text embedding.
+         tokenizer (`CLIPTokenizer`):
+             CLIP tokenizer.
+         default_processing_resolution (`int`, *optional*):
+             The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
+             the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
+             default value is used. This is required to ensure reasonable results with various model flavors trained
+             with varying optimal processing resolution values.
+     """
+
+     model_cpu_offload_seq = "text_encoder->unet->vae"
+
+     def __init__(
+         self,
+         unet: UNet2DConditionModel,
+         vae: AutoencoderKL,
+         scheduler: DDIMScheduler,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         default_processing_resolution: Optional[int] = 768,  # add
+     ):
+         super().__init__()
+
+         self.register_modules(
+             unet=unet,
+             vae=vae,
+             scheduler=scheduler,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+         )
+         self.register_to_config(
+             default_processing_resolution=default_processing_resolution,
+         )
+
+         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
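+         # For the standard Stable Diffusion VAE, `block_out_channels` has four entries, so this
+         # resolves to 2 ** 3 = 8 (an assumption about the loaded VAE; other configurations give a
+         # different spatial downscaling factor).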
+         self.default_processing_resolution = default_processing_resolution
+         self.empty_text_embedding = None
+
+         self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+     def check_inputs(
+         self,
+         image: PipelineImageInput,
+         processing_resolution: int,
+         resample_method_input: str,
+         resample_method_output: str,
+         batch_size: int,
+         output_type: str,
+     ) -> int:
+         if processing_resolution is None:
+             raise ValueError(
+                 "`processing_resolution` is not specified and could not be resolved from the model config."
+             )
+         if processing_resolution < 0:
+             raise ValueError(
+                 "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
+                 "downsampled processing."
+             )
+         if processing_resolution % self.vae_scale_factor != 0:
+             raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
+         if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+             raise ValueError(
+                 "`resample_method_input` takes string values compatible with PIL library: "
+                 "nearest, nearest-exact, bilinear, bicubic, area."
+             )
+         if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
+             raise ValueError(
+                 "`resample_method_output` takes string values compatible with PIL library: "
+                 "nearest, nearest-exact, bilinear, bicubic, area."
+             )
+         if batch_size < 1:
+             raise ValueError("`batch_size` must be positive.")
+         if output_type not in ["pt", "np"]:
+             raise ValueError("`output_type` must be one of `pt` or `np`.")
+
+         # image checks
+         num_images = 0
+         W, H = None, None
+         if not isinstance(image, list):
+             image = [image]
+         for i, img in enumerate(image):
+             if isinstance(img, np.ndarray) or torch.is_tensor(img):
+                 if img.ndim not in (2, 3, 4):
+                     raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
+                 H_i, W_i = img.shape[-2:]
+                 N_i = 1
+                 if img.ndim == 4:
+                     N_i = img.shape[0]
+             elif isinstance(img, Image.Image):
+                 W_i, H_i = img.size
+                 N_i = 1
+             else:
+                 raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
+             if W is None:
+                 W, H = W_i, H_i
+             elif (W, H) != (W_i, H_i):
+                 raise ValueError(
+                     f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
+                 )
+             num_images += N_i
+
+         if processing_resolution > 0:
+             max_orig = max(H, W)
+             new_H = H * processing_resolution // max_orig
+             new_W = W * processing_resolution // max_orig
+             if new_H == 0 or new_W == 0:
+                 raise ValueError(f"Extreme aspect ratio of the input image: [{W} x {H}]")
+             W, H = new_W, new_H
+         w = (W + self.vae_scale_factor - 1) // self.vae_scale_factor
+         h = (H + self.vae_scale_factor - 1) // self.vae_scale_factor
+         shape_expected = (num_images, self.vae.config.latent_channels, h, w)
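+         # Worked example with illustrative numbers: a 640x480 input and processing_resolution=768 give
+         # W, H = 768, 576; with vae_scale_factor=8 the latent grid is w, h = 96, 72, so `shape_expected`
+         # is (num_images, 4, 72, 96). The variable is not used further in this streamlined version.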
+
+         return num_images
+
+     def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
+         if not hasattr(self, "_progress_bar_config"):
+             self._progress_bar_config = {}
+         elif not isinstance(self._progress_bar_config, dict):
+             raise ValueError(
+                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+             )
+
+         progress_bar_config = dict(**self._progress_bar_config)
+         progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
+         progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
+         if iterable is not None:
+             return tqdm(iterable, **progress_bar_config)
+         elif total is not None:
+             return tqdm(total=total, **progress_bar_config)
+         else:
+             raise ValueError("Either `total` or `iterable` has to be defined.")
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         image: PipelineImageInput,
+         processing_resolution: Optional[int] = None,
+         match_input_resolution: bool = True,
+         resample_method_input: str = "bilinear",
+         resample_method_output: str = "bilinear",
+         batch_size: int = 1,
+         output_type: str = "np",
+         output_latent: bool = False,
+         return_dict: bool = True,
+     ):
+         """
+         Function invoked when calling the pipeline.
+
+         Args:
+             image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+                 `List[torch.Tensor]`): An input image or images used as an input for the normals estimation task. For
+                 arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
+                 by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
+                 three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
+                 same width and height.
+             processing_resolution (`int`, *optional*, defaults to `None`):
+                 Effective processing resolution. When set to `0`, matches the larger input image dimension. This
+                 produces crisper predictions, but may also lead to the overall loss of global context. The default
+                 value `None` resolves to the optimal value from the model config.
+             match_input_resolution (`bool`, *optional*, defaults to `True`):
+                 When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
+                 side of the output will equal `processing_resolution`.
+             resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
+                 Resampling method used to resize input images to `processing_resolution`. The accepted values are:
+                 `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+             resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
+                 Resampling method used to resize output predictions to match the input resolution. The accepted values
+                 are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
+             batch_size (`int`, *optional*, defaults to `1`):
+                 Batch size; only matters when passing a tensor of images.
+             output_type (`str`, *optional*, defaults to `"np"`):
+                 Preferred format of the output's `prediction`. The accepted values are: `"np"` (numpy array) or `"pt"` (torch tensor).
+             output_latent (`bool`, *optional*, defaults to `False`):
+                 When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
+                 within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
+                 `latents` argument.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return an [`E2EMarigoldNormalsOutput`] instead of a plain tuple.
+
+         # add
+         E2E FT models are deterministic single-step models involving no ensembling, i.e. E=1.
+         """
+
+         # 0. Resolving variables.
+         device = self._execution_device
+         dtype = self.dtype
+
+         # Model-specific optimal default values leading to fast and reasonable results.
+         if processing_resolution is None:
+             processing_resolution = self.default_processing_resolution
+
+         # 1. Check inputs.
+         num_images = self.check_inputs(
+             image,
+             processing_resolution,
+             resample_method_input,
+             resample_method_output,
+             batch_size,
+             output_type,
+         )
+
+         # 2. Prepare empty text conditioning.
+         # Model invocation: self.tokenizer, self.text_encoder.
+         if self.empty_text_embedding is None:
+             prompt = ""
+             text_inputs = self.tokenizer(
+                 prompt,
+                 padding="do_not_pad",
+                 max_length=self.tokenizer.model_max_length,
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             text_input_ids = text_inputs.input_ids.to(device)
+             self.empty_text_embedding = self.text_encoder(text_input_ids)[0]  # [1,2,1024]
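+             # With `padding="do_not_pad"`, the empty prompt tokenizes to just the begin/end tokens, so the
+             # sequence length is 2; the hidden size of 1024 corresponds to the Stable Diffusion 2 text
+             # encoder that Marigold builds on (an assumption about the loaded checkpoint).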
+
+         # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
+         # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
+         # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
+         # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
+         # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
+         # operation and leads to the most reasonable results. Using the native image resolution or any other processing
+         # resolution can lead to loss of either fine details or global context in the output predictions.
+         image, padding, original_resolution = self.image_processor.preprocess(
+             image, processing_resolution, resample_method_input, device, dtype
+         )  # [N,3,PPH,PPW]
+
+         # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
+         # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
+         # Latents of each such predictions across all input images and all ensemble members are represented in the
+         # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
+         # into latent space and replicated `E` times. Encoding into latent space happens in batches of size `batch_size`.
+         # Model invocation: self.vae.encoder.
+         image_latent, pred_latent = self.prepare_latents(
+             image, batch_size
+         )  # [N*E,4,h,w], [N*E,4,h,w]
+
+         del image
+
+         batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
+             batch_size, 1, 1
+         )  # [B,2,1024]
+
+         # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
+         # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
+         # outputs noise for the predicted modality's latent space.
+         # Model invocation: self.unet.
+         pred_latents = []
+
+         for i in self.progress_bar(
+             range(0, num_images, batch_size), leave=True, desc="E2E FT predictions..."
+         ):
+             batch_image_latent = image_latent[i : i + batch_size]  # [B,4,h,w]
+             batch_pred_latent = pred_latent[i : i + batch_size]  # [B,4,h,w]
+             effective_batch_size = batch_image_latent.shape[0]
+             text = batch_empty_text_embedding[:effective_batch_size]  # [B,2,1024]
+
+             # add
+             # Single-step inference for E2E FT models
+             self.scheduler.set_timesteps(1, device=device)
+             for t in self.progress_bar(self.scheduler.timesteps, leave=False, desc="Diffusion steps..."):
+                 batch_latent = torch.cat([batch_image_latent, batch_pred_latent], dim=1)  # [B,8,h,w]
+                 noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0]  # [B,4,h,w]
+                 batch_pred_latent = self.scheduler.step(
+                     noise, t, batch_pred_latent
+                 ).pred_original_sample  # [B,4,h,w], # add
+                 # directly take pred_original_sample rather than prev_sample
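+                 # With a single timestep and the scheduler's default epsilon prediction type (an assumption
+                 # about the loaded scheduler config), pred_original_sample recovers the clean latent in one
+                 # step as x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t).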
+
+             pred_latents.append(batch_pred_latent)
+
+         pred_latent = torch.cat(pred_latents, dim=0)  # [N*E,4,h,w]
+
+         del (
+             pred_latents,
+             image_latent,
+             batch_empty_text_embedding,
+             batch_image_latent,
+             batch_pred_latent,
+             text,
+             batch_latent,
+             noise,
+         )
+
+         # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
+         # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
+         # Model invocation: self.vae.decoder.
+         prediction = torch.cat(
+             [
+                 self.decode_prediction(pred_latent[i : i + batch_size])
+                 for i in range(0, pred_latent.shape[0], batch_size)
+             ],
+             dim=0,
+         )  # [N*E,3,PPH,PPW]
+
+         if not output_latent:
+             pred_latent = None
+
+         # 7. Remove padding. The output shape is (PH, PW).
+         prediction = self.image_processor.unpad_image(prediction, padding)  # [N*E,3,PH,PW]
+
+         # 8. If `match_input_resolution` is set, the output prediction is upsampled to match the
+         # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
+         # After upsampling, the native resolution normal maps are renormalized to unit length to reduce the artifacts.
+         # Depending on the downstream use-case, upsampling can also be chosen based on the tolerated artifacts by
+         # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
+         if match_input_resolution:
+             prediction = self.image_processor.resize_antialias(
+                 prediction, original_resolution, resample_method_output, is_aa=False
+             )  # [N,3,H,W]
+         prediction = self.normalize_normals(prediction)  # [N,3,H,W]
+
+         # 10. Prepare the final outputs.
+         if output_type == "np":
+             prediction = self.image_processor.pt_to_numpy(prediction)  # [N,H,W,3]
+
+         # 11. Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (prediction, pred_latent)
+
+         return E2EMarigoldNormalsOutput(
+             prediction=prediction,
+             latent=pred_latent,
+         )
+
+     # Copied from diffusers.pipelines.marigold.pipeline_marigold_depth.MarigoldDepthPipeline.prepare_latents
+     def prepare_latents(
+         self,
+         image: torch.Tensor,
+         batch_size: int,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         def retrieve_latents(encoder_output):
+             if hasattr(encoder_output, "latent_dist"):
+                 return encoder_output.latent_dist.mode()
+             elif hasattr(encoder_output, "latents"):
+                 return encoder_output.latents
+             else:
+                 raise AttributeError("Could not access latents of provided encoder_output")
+
+         image_latent = torch.cat(
+             [
+                 retrieve_latents(self.vae.encode(image[i : i + batch_size]))
+                 for i in range(0, image.shape[0], batch_size)
+             ],
+             dim=0,
+         )  # [N,4,h,w]
+         image_latent = image_latent * self.vae.config.scaling_factor  # [N*E,4,h,w]
+
+         # add
+         # provide zeros as noised latent
+         pred_latent = zeros_tensor(
+             image_latent.shape,
+             device=image_latent.device,
+             dtype=image_latent.dtype,
+         )  # [N*E,4,h,w]
+
+         return image_latent, pred_latent
+
+     def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
+         if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
+             raise ValueError(
+                 f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
+             )
+
+         prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0]  # [B,3,H,W]
+
+         # add
+         prediction = self.normalize_normals(prediction)  # [B,3,H,W]
+         prediction = torch.clip(prediction, -1.0, 1.0)
+
+         return prediction  # [B,3,H,W]
+
+     @staticmethod
+     def normalize_normals(normals: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+         if normals.dim() != 4 or normals.shape[1] != 3:
+             raise ValueError(f"Expecting 4D tensor of shape [B,3,H,W]; got {normals.shape}.")
+
+         norm = torch.norm(normals, dim=1, keepdim=True)
+         normals /= norm.clamp(min=eps)
+
+         return normals
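+
+
+ # A minimal sanity-check sketch (illustrative only, kept as a comment so nothing runs on import):
+ # `normalize_normals` is a staticmethod, so it can be exercised without building the full pipeline.
+ #
+ #   x = torch.randn(2, 3, 8, 8)
+ #   n = E2EMarigoldNormalsPipeline.normalize_normals(x)
+ #   assert torch.allclose(n.norm(dim=1), torch.ones(2, 8, 8), atol=1e-4)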