ashawkey commited on
Commit
7521d91
1 Parent(s): d2af8f9

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline_mvdream.py +558 -0
pipeline_mvdream.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import inspect
4
+ import numpy as np
5
+ from typing import Callable, List, Optional, Union
6
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
7
+ from diffusers import AutoencoderKL, DiffusionPipeline
8
+ from diffusers.utils import (
9
+ deprecate,
10
+ is_accelerate_available,
11
+ is_accelerate_version,
12
+ logging,
13
+ )
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.schedulers import DDIMScheduler
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ from mv_unet import MultiViewUNetModel, get_camera
19
+
20
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
+
22
+
23
+ class MVDreamPipeline(DiffusionPipeline):
24
+
25
+ _optional_components = ["feature_extractor", "image_encoder"]
26
+
27
+ def __init__(
28
+ self,
29
+ vae: AutoencoderKL,
30
+ unet: MultiViewUNetModel,
31
+ tokenizer: CLIPTokenizer,
32
+ text_encoder: CLIPTextModel,
33
+ scheduler: DDIMScheduler,
34
+ # imagedream variant
35
+ feature_extractor: CLIPImageProcessor,
36
+ image_encoder: CLIPVisionModel,
37
+ requires_safety_checker: bool = False,
38
+ ):
39
+ super().__init__()
40
+
41
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
42
+ deprecation_message = (
43
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
44
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
45
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
46
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
47
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
48
+ " file"
49
+ )
50
+ deprecate(
51
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
52
+ )
53
+ new_config = dict(scheduler.config)
54
+ new_config["steps_offset"] = 1
55
+ scheduler._internal_dict = FrozenDict(new_config)
56
+
57
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
58
+ deprecation_message = (
59
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
60
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
61
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
62
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
63
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
64
+ )
65
+ deprecate(
66
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
67
+ )
68
+ new_config = dict(scheduler.config)
69
+ new_config["clip_sample"] = False
70
+ scheduler._internal_dict = FrozenDict(new_config)
71
+
72
+ self.register_modules(
73
+ vae=vae,
74
+ unet=unet,
75
+ scheduler=scheduler,
76
+ tokenizer=tokenizer,
77
+ text_encoder=text_encoder,
78
+ feature_extractor=feature_extractor,
79
+ image_encoder=image_encoder,
80
+ )
81
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
82
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
83
+
84
+ def enable_vae_slicing(self):
85
+ r"""
86
+ Enable sliced VAE decoding.
87
+
88
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
89
+ steps. This is useful to save some memory and allow larger batch sizes.
90
+ """
91
+ self.vae.enable_slicing()
92
+
93
+ def disable_vae_slicing(self):
94
+ r"""
95
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
96
+ computing decoding in one step.
97
+ """
98
+ self.vae.disable_slicing()
99
+
100
+ def enable_vae_tiling(self):
101
+ r"""
102
+ Enable tiled VAE decoding.
103
+
104
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
105
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
106
+ """
107
+ self.vae.enable_tiling()
108
+
109
+ def disable_vae_tiling(self):
110
+ r"""
111
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
112
+ computing decoding in one step.
113
+ """
114
+ self.vae.disable_tiling()
115
+
116
+ def enable_sequential_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
119
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
120
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
121
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
122
+ `enable_model_cpu_offload`, but performance is lower.
123
+ """
124
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
125
+ from accelerate import cpu_offload
126
+ else:
127
+ raise ImportError(
128
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
129
+ )
130
+
131
+ device = torch.device(f"cuda:{gpu_id}")
132
+
133
+ if self.device.type != "cpu":
134
+ self.to("cpu", silence_dtype_warnings=True)
135
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
136
+
137
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
138
+ cpu_offload(cpu_offloaded_model, device)
139
+
140
+ def enable_model_cpu_offload(self, gpu_id=0):
141
+ r"""
142
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
143
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
144
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
145
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
146
+ """
147
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
148
+ from accelerate import cpu_offload_with_hook
149
+ else:
150
+ raise ImportError(
151
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
152
+ )
153
+
154
+ device = torch.device(f"cuda:{gpu_id}")
155
+
156
+ if self.device.type != "cpu":
157
+ self.to("cpu", silence_dtype_warnings=True)
158
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
159
+
160
+ hook = None
161
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
162
+ _, hook = cpu_offload_with_hook(
163
+ cpu_offloaded_model, device, prev_module_hook=hook
164
+ )
165
+
166
+ # We'll offload the last model manually.
167
+ self.final_offload_hook = hook
168
+
169
+ @property
170
+ def _execution_device(self):
171
+ r"""
172
+ Returns the device on which the pipeline's models will be executed. After calling
173
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
174
+ hooks.
175
+ """
176
+ if not hasattr(self.unet, "_hf_hook"):
177
+ return self.device
178
+ for module in self.unet.modules():
179
+ if (
180
+ hasattr(module, "_hf_hook")
181
+ and hasattr(module._hf_hook, "execution_device")
182
+ and module._hf_hook.execution_device is not None
183
+ ):
184
+ return torch.device(module._hf_hook.execution_device)
185
+ return self.device
186
+
187
+ def _encode_prompt(
188
+ self,
189
+ prompt,
190
+ device,
191
+ num_images_per_prompt,
192
+ do_classifier_free_guidance: bool,
193
+ negative_prompt=None,
194
+ ):
195
+ r"""
196
+ Encodes the prompt into text encoder hidden states.
197
+
198
+ Args:
199
+ prompt (`str` or `List[str]`, *optional*):
200
+ prompt to be encoded
201
+ device: (`torch.device`):
202
+ torch device
203
+ num_images_per_prompt (`int`):
204
+ number of images that should be generated per prompt
205
+ do_classifier_free_guidance (`bool`):
206
+ whether to use classifier free guidance or not
207
+ negative_prompt (`str` or `List[str]`, *optional*):
208
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
209
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
210
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
211
+ prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
216
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
217
+ argument.
218
+ """
219
+ if prompt is not None and isinstance(prompt, str):
220
+ batch_size = 1
221
+ elif prompt is not None and isinstance(prompt, list):
222
+ batch_size = len(prompt)
223
+ else:
224
+ raise ValueError(
225
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
226
+ )
227
+
228
+ text_inputs = self.tokenizer(
229
+ prompt,
230
+ padding="max_length",
231
+ max_length=self.tokenizer.model_max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+ text_input_ids = text_inputs.input_ids
236
+ untruncated_ids = self.tokenizer(
237
+ prompt, padding="longest", return_tensors="pt"
238
+ ).input_ids
239
+
240
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
241
+ text_input_ids, untruncated_ids
242
+ ):
243
+ removed_text = self.tokenizer.batch_decode(
244
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
245
+ )
246
+ logger.warning(
247
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
248
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
249
+ )
250
+
251
+ if (
252
+ hasattr(self.text_encoder.config, "use_attention_mask")
253
+ and self.text_encoder.config.use_attention_mask
254
+ ):
255
+ attention_mask = text_inputs.attention_mask.to(device)
256
+ else:
257
+ attention_mask = None
258
+
259
+ prompt_embeds = self.text_encoder(
260
+ text_input_ids.to(device),
261
+ attention_mask=attention_mask,
262
+ )
263
+ prompt_embeds = prompt_embeds[0]
264
+
265
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
+
267
+ bs_embed, seq_len, _ = prompt_embeds.shape
268
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(
271
+ bs_embed * num_images_per_prompt, seq_len, -1
272
+ )
273
+
274
+ # get unconditional embeddings for classifier free guidance
275
+ if do_classifier_free_guidance:
276
+ uncond_tokens: List[str]
277
+ if negative_prompt is None:
278
+ uncond_tokens = [""] * batch_size
279
+ elif type(prompt) is not type(negative_prompt):
280
+ raise TypeError(
281
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
282
+ f" {type(prompt)}."
283
+ )
284
+ elif isinstance(negative_prompt, str):
285
+ uncond_tokens = [negative_prompt]
286
+ elif batch_size != len(negative_prompt):
287
+ raise ValueError(
288
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
289
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
290
+ " the batch size of `prompt`."
291
+ )
292
+ else:
293
+ uncond_tokens = negative_prompt
294
+
295
+ max_length = prompt_embeds.shape[1]
296
+ uncond_input = self.tokenizer(
297
+ uncond_tokens,
298
+ padding="max_length",
299
+ max_length=max_length,
300
+ truncation=True,
301
+ return_tensors="pt",
302
+ )
303
+
304
+ if (
305
+ hasattr(self.text_encoder.config, "use_attention_mask")
306
+ and self.text_encoder.config.use_attention_mask
307
+ ):
308
+ attention_mask = uncond_input.attention_mask.to(device)
309
+ else:
310
+ attention_mask = None
311
+
312
+ negative_prompt_embeds = self.text_encoder(
313
+ uncond_input.input_ids.to(device),
314
+ attention_mask=attention_mask,
315
+ )
316
+ negative_prompt_embeds = negative_prompt_embeds[0]
317
+
318
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
319
+ seq_len = negative_prompt_embeds.shape[1]
320
+
321
+ negative_prompt_embeds = negative_prompt_embeds.to(
322
+ dtype=self.text_encoder.dtype, device=device
323
+ )
324
+
325
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
326
+ 1, num_images_per_prompt, 1
327
+ )
328
+ negative_prompt_embeds = negative_prompt_embeds.view(
329
+ batch_size * num_images_per_prompt, seq_len, -1
330
+ )
331
+
332
+ # For classifier free guidance, we need to do two forward passes.
333
+ # Here we concatenate the unconditional and text embeddings into a single batch
334
+ # to avoid doing two forward passes
335
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
336
+
337
+ return prompt_embeds
338
+
339
+ def decode_latents(self, latents):
340
+ latents = 1 / self.vae.config.scaling_factor * latents
341
+ image = self.vae.decode(latents).sample
342
+ image = (image / 2 + 0.5).clamp(0, 1)
343
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
344
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
345
+ return image
346
+
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(
354
+ inspect.signature(self.scheduler.step).parameters.keys()
355
+ )
356
+ extra_step_kwargs = {}
357
+ if accepts_eta:
358
+ extra_step_kwargs["eta"] = eta
359
+
360
+ # check if the scheduler accepts generator
361
+ accepts_generator = "generator" in set(
362
+ inspect.signature(self.scheduler.step).parameters.keys()
363
+ )
364
+ if accepts_generator:
365
+ extra_step_kwargs["generator"] = generator
366
+ return extra_step_kwargs
367
+
368
+ def prepare_latents(
369
+ self,
370
+ batch_size,
371
+ num_channels_latents,
372
+ height,
373
+ width,
374
+ dtype,
375
+ device,
376
+ generator,
377
+ latents=None,
378
+ ):
379
+ shape = (
380
+ batch_size,
381
+ num_channels_latents,
382
+ height // self.vae_scale_factor,
383
+ width // self.vae_scale_factor,
384
+ )
385
+ if isinstance(generator, list) and len(generator) != batch_size:
386
+ raise ValueError(
387
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
388
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
389
+ )
390
+
391
+ if latents is None:
392
+ latents = randn_tensor(
393
+ shape, generator=generator, device=device, dtype=dtype
394
+ )
395
+ else:
396
+ latents = latents.to(device)
397
+
398
+ # scale the initial noise by the standard deviation required by the scheduler
399
+ latents = latents * self.scheduler.init_noise_sigma
400
+ return latents
401
+
402
+ def encode_image(self, image, device, num_images_per_prompt):
403
+ dtype = next(self.image_encoder.parameters()).dtype
404
+
405
+ if image.dtype == np.float32:
406
+ image = (image * 255).astype(np.uint8)
407
+
408
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
409
+ image = image.to(device=device, dtype=dtype)
410
+
411
+ image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
412
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
413
+
414
+ return torch.zeros_like(image_embeds), image_embeds
415
+
416
+ def encode_image_latents(self, image, device, num_images_per_prompt):
417
+
418
+ dtype = next(self.image_encoder.parameters()).dtype
419
+
420
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
421
+ image = 2 * image - 1
422
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
423
+ image = image.to(dtype=dtype)
424
+
425
+ posterior = self.vae.encode(image).latent_dist
426
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
427
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
428
+
429
+ return torch.zeros_like(latents), latents
430
+
431
+ @torch.no_grad()
432
+ def __call__(
433
+ self,
434
+ prompt: str = "",
435
+ image: Optional[np.ndarray] = None,
436
+ height: int = 256,
437
+ width: int = 256,
438
+ num_inference_steps: int = 50,
439
+ guidance_scale: float = 7.0,
440
+ negative_prompt: str = "",
441
+ num_images_per_prompt: int = 1,
442
+ eta: float = 0.0,
443
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
444
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
445
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
446
+ callback_steps: int = 1,
447
+ num_frames: int = 4,
448
+ device=torch.device("cuda:0"),
449
+ ):
450
+ self.unet = self.unet.to(device=device)
451
+ self.vae = self.vae.to(device=device)
452
+ self.text_encoder = self.text_encoder.to(device=device)
453
+
454
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
455
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
456
+ # corresponds to doing no classifier free guidance.
457
+ do_classifier_free_guidance = guidance_scale > 1.0
458
+
459
+ # Prepare timesteps
460
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
461
+ timesteps = self.scheduler.timesteps
462
+
463
+ # imagedream variant
464
+ if image is not None:
465
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
466
+ self.image_encoder = self.image_encoder.to(device=device)
467
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
468
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
469
+
470
+ _prompt_embeds = self._encode_prompt(
471
+ prompt=prompt,
472
+ device=device,
473
+ num_images_per_prompt=num_images_per_prompt,
474
+ do_classifier_free_guidance=do_classifier_free_guidance,
475
+ negative_prompt=negative_prompt,
476
+ ) # type: ignore
477
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
478
+
479
+ # Prepare latent variables
480
+ actual_num_frames = num_frames if image is None else num_frames + 1
481
+ latents: torch.Tensor = self.prepare_latents(
482
+ actual_num_frames * num_images_per_prompt,
483
+ 4,
484
+ height,
485
+ width,
486
+ prompt_embeds_pos.dtype,
487
+ device,
488
+ generator,
489
+ None,
490
+ )
491
+
492
+ if image is not None:
493
+ camera = get_camera(num_frames, elevation=5, extra_view=True).to(dtype=latents.dtype, device=device)
494
+ else:
495
+ camera = get_camera(num_frames, elevation=15, extra_view=False).to(dtype=latents.dtype, device=device)
496
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
497
+
498
+ # Prepare extra step kwargs.
499
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
500
+
501
+ # Denoising loop
502
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
503
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
504
+ for i, t in enumerate(timesteps):
505
+ # expand the latents if we are doing classifier free guidance
506
+ multiplier = 2 if do_classifier_free_guidance else 1
507
+ latent_model_input = torch.cat([latents] * multiplier)
508
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
509
+
510
+ unet_inputs = {
511
+ 'x': latent_model_input,
512
+ 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
513
+ 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
514
+ 'num_frames': actual_num_frames,
515
+ 'camera': torch.cat([camera] * multiplier),
516
+ }
517
+
518
+ if image is not None:
519
+ unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
520
+ unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
521
+
522
+ # predict the noise residual
523
+ noise_pred = self.unet.forward(**unet_inputs)
524
+
525
+ # perform guidance
526
+ if do_classifier_free_guidance:
527
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
528
+ noise_pred = noise_pred_uncond + guidance_scale * (
529
+ noise_pred_text - noise_pred_uncond
530
+ )
531
+
532
+ # compute the previous noisy sample x_t -> x_t-1
533
+ latents: torch.Tensor = self.scheduler.step(
534
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
535
+ )[0]
536
+
537
+ # call the callback, if provided
538
+ if i == len(timesteps) - 1 or (
539
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
540
+ ):
541
+ progress_bar.update()
542
+ if callback is not None and i % callback_steps == 0:
543
+ callback(i, t, latents) # type: ignore
544
+
545
+ # Post-processing
546
+ if output_type == "latent":
547
+ image = latents
548
+ elif output_type == "pil":
549
+ image = self.decode_latents(latents)
550
+ image = self.numpy_to_pil(image)
551
+ else: # numpy
552
+ image = self.decode_latents(latents)
553
+
554
+ # Offload last model to CPU
555
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
556
+ self.final_offload_hook.offload()
557
+
558
+ return image