sayakpaul (HF staff) committed
Commit 860a8cc
1 Parent(s): 306b03a

Upload pipeline_t2v_base_pixel.py

Files changed (1)
  1. pipeline_t2v_base_pixel.py +835 -0
pipeline_t2v_base_pixel.py ADDED
@@ -0,0 +1,835 @@
1
+ # Copyright 2023 Show Labs, Alibaba DAMO-VILAB, and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import html
17
+ import inspect
18
+ import re
19
+ import urllib.parse as ul
20
+ from dataclasses import dataclass
21
+ from typing import Any, Callable, Dict, List, Optional, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
27
+
28
+ from diffusers import UNet3DConditionModel
29
+ from diffusers.loaders import LoraLoaderMixin
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.schedulers import DDPMScheduler
32
+ from diffusers.utils import (
33
+ BACKENDS_MAPPING,
34
+ BaseOutput,
35
+ is_accelerate_available,
36
+ is_accelerate_version,
37
+ is_bs4_available,
38
+ is_ftfy_available,
39
+ logging,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+ if is_bs4_available():
47
+ from bs4 import BeautifulSoup
48
+
49
+ if is_ftfy_available():
50
+ import ftfy
51
+
52
+
53
+ @dataclass
54
+ class TextToVideoPipelineOutput(BaseOutput):
55
+ """
56
+ Output class for text to video pipelines.
57
+
58
+ Args:
59
+ frames (`List[np.ndarray]` or `torch.FloatTensor`):
60
+ List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
61
+ a `torch` tensor. The NumPy arrays represent the denoised images produced by the diffusion pipeline. The length of the list
62
+ denotes the video length, i.e., the number of frames.
63
+ """
64
+
65
+ frames: Union[List[np.ndarray], torch.FloatTensor]
66
+
67
+
68
+ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
69
+ # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
70
+ # reshape to ncfhw
71
+ mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
72
+ std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
73
+ # unnormalize back to [0,1]
74
+ video = video.mul_(std).add_(mean)
75
+ video.clamp_(0, 1)
76
+ # prepare the final outputs
77
+ i, c, f, h, w = video.shape
78
+ images = video.permute(2, 3, 0, 4, 1).reshape(
79
+ f, h, i * w, c
80
+ ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
81
+ images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
82
+ images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
83
+ return images
84
+
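+ # Shape walk-through (illustrative only, assuming a toy input): for a tensor of shape
+ # (batch=1, channels=3, frames=16, height=256, width=256), `tensor2vid` permutes to
+ # (frames, height, batch, width, channels), reshapes to (16, 256, 1 * 256, 3), and returns
+ # a list of 16 uint8 HxWxC frame arrays.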
85
+
86
+ class TextToVideoIFPipeline(DiffusionPipeline, LoraLoaderMixin):
87
+ tokenizer: T5Tokenizer
88
+ text_encoder: T5EncoderModel
89
+
90
+ unet: UNet3DConditionModel
91
+ scheduler: DDPMScheduler
92
+
93
+ feature_extractor: Optional[CLIPImageProcessor]
94
+ # safety_checker: Optional[IFSafetyChecker]
95
+
96
+ # watermarker: Optional[IFWatermarker]
97
+
98
+ bad_punct_regex = re.compile(
99
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
100
+ ) # noqa
101
+
102
+ _optional_components = [
103
+ "tokenizer",
104
+ "text_encoder",
105
+ "safety_checker",
106
+ "feature_extractor",
107
+ "watermarker",
108
+ ]
109
+
110
+ def __init__(
111
+ self,
112
+ tokenizer: T5Tokenizer,
113
+ text_encoder: T5EncoderModel,
114
+ unet: UNet3DConditionModel,
115
+ scheduler: DDPMScheduler,
116
+ feature_extractor: Optional[CLIPImageProcessor],
117
+ ):
118
+ super().__init__()
119
+
120
+ self.register_modules(
121
+ tokenizer=tokenizer,
122
+ text_encoder=text_encoder,
123
+ unet=unet,
124
+ scheduler=scheduler,
125
+ feature_extractor=feature_extractor,
126
+ )
127
+ self.safety_checker = None
128
+
129
+ def enable_sequential_cpu_offload(self, gpu_id=0):
130
+ r"""
131
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
132
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to the GPU only
133
+ when their specific submodule has its `forward` method called.
134
+ """
135
+ if is_accelerate_available():
136
+ from accelerate import cpu_offload
137
+ else:
138
+ raise ImportError("Please install accelerate via `pip install accelerate`")
139
+
140
+ device = torch.device(f"cuda:{gpu_id}")
141
+
142
+ models = [
143
+ self.text_encoder,
144
+ self.unet,
145
+ ]
146
+ for cpu_offloaded_model in models:
147
+ if cpu_offloaded_model is not None:
148
+ cpu_offload(cpu_offloaded_model, device)
149
+
150
+ if self.safety_checker is not None:
151
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
152
+
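+ # Usage sketch (assumes a constructed pipeline instance named `pipe` and an available CUDA device):
+ # pipe.enable_sequential_cpu_offload(gpu_id=0)
+ # output = pipe("a corgi running on the beach", num_frames=16)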
153
+ def enable_model_cpu_offload(self, gpu_id=0):
154
+ r"""
155
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
156
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
157
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
158
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
159
+ """
160
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
161
+ from accelerate import cpu_offload_with_hook
162
+ else:
163
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
164
+
165
+ device = torch.device(f"cuda:{gpu_id}")
166
+
167
+ self.unet.train()
168
+
169
+ if self.device.type != "cpu":
170
+ self.to("cpu", silence_dtype_warnings=True)
171
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
172
+
173
+ hook = None
174
+
175
+ if self.text_encoder is not None:
176
+ _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
177
+
178
+ # Accelerate will move the next model to the device _before_ calling the offload hook of the
179
+ # previous model. This will cause both models to be present on the device at the same time.
180
+ # IF uses T5 for its text encoder which is really large. We can manually call the offload
181
+ # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
182
+ # the GPU.
183
+ self.text_encoder_offload_hook = hook
184
+
185
+ _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
186
+
187
+ # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
188
+ self.unet_offload_hook = hook
189
+
190
+ if self.safety_checker is not None:
191
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
192
+
193
+ # We'll offload the last model manually.
194
+ self.final_offload_hook = hook
195
+
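+ # Usage sketch (same hypothetical `pipe` instance as above): model-level offload keeps each whole
+ # sub-model on the GPU only while it runs, which is usually faster than sequential offload at a
+ # modest memory cost:
+ # pipe.enable_model_cpu_offload(gpu_id=0)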
196
+ def remove_all_hooks(self):
197
+ if is_accelerate_available():
198
+ from accelerate.hooks import remove_hook_from_module
199
+ else:
200
+ raise ImportError("Please install accelerate via `pip install accelerate`")
201
+
202
+ for model in [self.text_encoder, self.unet, self.safety_checker]:
203
+ if model is not None:
204
+ remove_hook_from_module(model, recurse=True)
205
+
206
+ self.unet_offload_hook = None
207
+ self.text_encoder_offload_hook = None
208
+ self.final_offload_hook = None
209
+
210
+ @property
211
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
212
+ def _execution_device(self):
213
+ r"""
214
+ Returns the device on which the pipeline's models will be executed. After calling
215
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
216
+ hooks.
217
+ """
218
+ if not hasattr(self.unet, "_hf_hook"):
219
+ return self.device
220
+ for module in self.unet.modules():
221
+ if (
222
+ hasattr(module, "_hf_hook")
223
+ and hasattr(module._hf_hook, "execution_device")
224
+ and module._hf_hook.execution_device is not None
225
+ ):
226
+ return torch.device(module._hf_hook.execution_device)
227
+ return self.device
228
+
229
+ @torch.no_grad()
230
+ def encode_prompt(
231
+ self,
232
+ prompt,
233
+ do_classifier_free_guidance=True,
234
+ num_images_per_prompt=1,
235
+ device=None,
236
+ negative_prompt=None,
237
+ prompt_embeds: Optional[torch.FloatTensor] = None,
238
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
239
+ clean_caption: bool = False,
240
+ ):
241
+ r"""
242
+ Encodes the prompt into text encoder hidden states.
243
+
244
+ Args:
245
+ prompt (`str` or `List[str]`, *optional*):
246
+ prompt to be encoded
247
+ device (`torch.device`, *optional*):
248
+ torch device to place the resulting embeddings on
249
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
250
+ number of images that should be generated per prompt
251
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
252
+ whether to use classifier free guidance or not
253
+ negative_prompt (`str` or `List[str]`, *optional*):
254
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
255
+ `negative_prompt_embeds` instead.
256
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
257
+ prompt_embeds (`torch.FloatTensor`, *optional*):
258
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
259
+ provided, text embeddings will be generated from `prompt` input argument.
260
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
261
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
262
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
263
+ argument.
264
+ """
265
+ if prompt is not None and negative_prompt is not None:
266
+ if type(prompt) is not type(negative_prompt):
267
+ raise TypeError(
268
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
269
+ f" {type(prompt)}."
270
+ )
271
+
272
+ if device is None:
273
+ device = self._execution_device
274
+
275
+ if prompt is not None and isinstance(prompt, str):
276
+ batch_size = 1
277
+ elif prompt is not None and isinstance(prompt, list):
278
+ batch_size = len(prompt)
279
+ else:
280
+ batch_size = prompt_embeds.shape[0]
281
+
282
+ # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
283
+ max_length = 77
284
+
285
+ if prompt_embeds is None:
286
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
287
+ text_inputs = self.tokenizer(
288
+ prompt,
289
+ padding="max_length",
290
+ max_length=max_length,
291
+ truncation=True,
292
+ add_special_tokens=True,
293
+ return_tensors="pt",
294
+ )
295
+ text_input_ids = text_inputs.input_ids
296
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
297
+
298
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
299
+ text_input_ids, untruncated_ids
300
+ ):
301
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
302
+ logger.warning(
303
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
304
+ f" {max_length} tokens: {removed_text}"
305
+ )
306
+
307
+ attention_mask = text_inputs.attention_mask.to(device)
308
+
309
+ prompt_embeds = self.text_encoder(
310
+ text_input_ids.to(device),
311
+ attention_mask=attention_mask,
312
+ )
313
+ prompt_embeds = prompt_embeds[0]
314
+
315
+ if self.text_encoder is not None:
316
+ dtype = self.text_encoder.dtype
317
+ elif self.unet is not None:
318
+ dtype = self.unet.dtype
319
+ else:
320
+ dtype = None
321
+
322
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
323
+
324
+ bs_embed, seq_len, _ = prompt_embeds.shape
325
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
326
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
327
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
328
+
329
+ # get unconditional embeddings for classifier free guidance
330
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
331
+ uncond_tokens: List[str]
332
+ if negative_prompt is None:
333
+ uncond_tokens = [""] * batch_size
334
+ elif isinstance(negative_prompt, str):
335
+ uncond_tokens = [negative_prompt]
336
+ elif batch_size != len(negative_prompt):
337
+ raise ValueError(
338
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
339
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
340
+ " the batch size of `prompt`."
341
+ )
342
+ else:
343
+ uncond_tokens = negative_prompt
344
+
345
+ uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
346
+ max_length = prompt_embeds.shape[1]
347
+ uncond_input = self.tokenizer(
348
+ uncond_tokens,
349
+ padding="max_length",
350
+ max_length=max_length,
351
+ truncation=True,
352
+ return_attention_mask=True,
353
+ add_special_tokens=True,
354
+ return_tensors="pt",
355
+ )
356
+ attention_mask = uncond_input.attention_mask.to(device)
357
+
358
+ negative_prompt_embeds = self.text_encoder(
359
+ uncond_input.input_ids.to(device),
360
+ attention_mask=attention_mask,
361
+ )
362
+ negative_prompt_embeds = negative_prompt_embeds[0]
363
+
364
+ if do_classifier_free_guidance:
365
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
366
+ seq_len = negative_prompt_embeds.shape[1]
367
+
368
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
369
+
370
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
371
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
372
+
373
+ # For classifier free guidance, we need to do two forward passes.
374
+ # Here we concatenate the unconditional and text embeddings into a single batch
375
+ # to avoid doing two forward passes
376
+ else:
377
+ negative_prompt_embeds = None
378
+
379
+ return prompt_embeds, negative_prompt_embeds
380
+
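+ # Sketch of precomputing embeddings (hypothetical `pipe` instance): the returned tensors can be fed
+ # back through the `prompt_embeds` / `negative_prompt_embeds` arguments of `__call__`, e.g.
+ # pe, npe = pipe.encode_prompt("a timelapse of a blooming flower", negative_prompt="blurry")
+ # output = pipe(prompt_embeds=pe, negative_prompt_embeds=npe, num_frames=16)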
381
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
382
+ def prepare_extra_step_kwargs(self, generator, eta):
383
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
384
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
385
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
386
+ # and should be between [0, 1]
387
+
388
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
389
+ extra_step_kwargs = {}
390
+ if accepts_eta:
391
+ extra_step_kwargs["eta"] = eta
392
+
393
+ # check if the scheduler accepts generator
394
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
395
+ if accepts_generator:
396
+ extra_step_kwargs["generator"] = generator
397
+ return extra_step_kwargs
398
+
399
+ def check_inputs(
400
+ self,
401
+ prompt,
402
+ callback_steps,
403
+ negative_prompt=None,
404
+ prompt_embeds=None,
405
+ negative_prompt_embeds=None,
406
+ ):
407
+ if (callback_steps is None) or (
408
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
409
+ ):
410
+ raise ValueError(
411
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
412
+ f" {type(callback_steps)}."
413
+ )
414
+
415
+ if prompt is not None and prompt_embeds is not None:
416
+ raise ValueError(
417
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
418
+ " only forward one of the two."
419
+ )
420
+ elif prompt is None and prompt_embeds is None:
421
+ raise ValueError(
422
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
423
+ )
424
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
425
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
426
+
427
+ if negative_prompt is not None and negative_prompt_embeds is not None:
428
+ raise ValueError(
429
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
430
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
431
+ )
432
+
433
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
434
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
435
+ raise ValueError(
436
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
437
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
438
+ f" {negative_prompt_embeds.shape}."
439
+ )
440
+
441
+ def prepare_intermediate_images(
442
+ self,
443
+ batch_size,
444
+ num_channels,
445
+ num_frames,
446
+ height,
447
+ width,
448
+ dtype,
449
+ device,
450
+ generator,
451
+ ):
452
+ shape = (batch_size, num_channels, num_frames, height, width)
453
+ if isinstance(generator, list) and len(generator) != batch_size:
454
+ raise ValueError(
455
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
456
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
457
+ )
458
+
459
+ intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
460
+
461
+ # scale the initial noise by the standard deviation required by the scheduler
462
+ intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
463
+ return intermediate_images
464
+
465
+ def _text_preprocessing(self, text, clean_caption=False):
466
+ if clean_caption and not is_bs4_available():
467
+ logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
468
+ logger.warn("Setting `clean_caption` to False...")
469
+ clean_caption = False
470
+
471
+ if clean_caption and not is_ftfy_available():
472
+ logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
473
+ logger.warn("Setting `clean_caption` to False...")
474
+ clean_caption = False
475
+
476
+ if not isinstance(text, (tuple, list)):
477
+ text = [text]
478
+
479
+ def process(text: str):
480
+ if clean_caption:
481
+ text = self._clean_caption(text)
482
+ text = self._clean_caption(text)
483
+ else:
484
+ text = text.lower().strip()
485
+ return text
486
+
487
+ return [process(t) for t in text]
488
+
489
+ def _clean_caption(self, caption):
490
+ caption = str(caption)
491
+ caption = ul.unquote_plus(caption)
492
+ caption = caption.strip().lower()
493
+ caption = re.sub("<person>", "person", caption)
494
+ # urls:
495
+ caption = re.sub(
496
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
497
+ "",
498
+ caption,
499
+ ) # regex for urls
500
+ caption = re.sub(
501
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
502
+ "",
503
+ caption,
504
+ ) # regex for urls
505
+ # html:
506
+ caption = BeautifulSoup(caption, features="html.parser").text
507
+
508
+ # @<nickname>
509
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
510
+
511
+ # 31C0—31EF CJK Strokes
512
+ # 31F0—31FF Katakana Phonetic Extensions
513
+ # 3200—32FF Enclosed CJK Letters and Months
514
+ # 3300—33FF CJK Compatibility
515
+ # 3400—4DBF CJK Unified Ideographs Extension A
516
+ # 4DC0—4DFF Yijing Hexagram Symbols
517
+ # 4E00—9FFF CJK Unified Ideographs
518
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
519
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
520
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
521
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
522
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
523
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
524
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
525
+ #######################################################
526
+
527
+ # все виды тире / all types of dash --> "-"
528
+ caption = re.sub(
529
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
530
+ "-",
531
+ caption,
532
+ )
533
+
534
+ # кавычки к одному стандарту
535
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
536
+ caption = re.sub(r"[‘’]", "'", caption)
537
+
538
+ # &quot;
539
+ caption = re.sub(r"&quot;?", "", caption)
540
+ # &amp
541
+ caption = re.sub(r"&amp", "", caption)
542
+
543
+ # ip adresses:
544
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
545
+
546
+ # article ids:
547
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
548
+
549
+ # \n
550
+ caption = re.sub(r"\\n", " ", caption)
551
+
552
+ # "#123"
553
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
554
+ # "#12345.."
555
+ caption = re.sub(r"#\d{5,}\b", "", caption)
556
+ # "123456.."
557
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
558
+ # filenames:
559
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
560
+
561
+ #
562
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
563
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
564
+
565
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
566
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
567
+
568
+ # this-is-my-cute-cat / this_is_my_cute_cat
569
+ regex2 = re.compile(r"(?:\-|\_)")
570
+ if len(re.findall(regex2, caption)) > 3:
571
+ caption = re.sub(regex2, " ", caption)
572
+
573
+ caption = ftfy.fix_text(caption)
574
+ caption = html.unescape(html.unescape(caption))
575
+
576
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
577
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
578
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
579
+
580
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
581
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
582
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
583
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
584
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
585
+
586
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
587
+
588
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
589
+
590
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
591
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
592
+ caption = re.sub(r"\s+", " ", caption)
593
+
594
+ caption = caption.strip()
595
+
596
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
597
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
598
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
599
+ caption = re.sub(r"^\.\S+$", "", caption)
600
+
601
+ return caption.strip()
602
+
603
+ @torch.no_grad()
604
+ def __call__(
605
+ self,
606
+ prompt: Union[str, List[str]] = None,
607
+ num_inference_steps: int = 100,
608
+ timesteps: List[int] = None,
609
+ guidance_scale: float = 7.0,
610
+ negative_prompt: Optional[Union[str, List[str]]] = None,
611
+ num_images_per_prompt: Optional[int] = 1,
612
+ height: Optional[int] = None,
613
+ width: Optional[int] = None,
614
+ num_frames: int = 16,
615
+ eta: float = 0.0,
616
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
617
+ prompt_embeds: Optional[torch.FloatTensor] = None,
618
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
619
+ output_type: Optional[str] = "np",
620
+ return_dict: bool = True,
621
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
622
+ callback_steps: int = 1,
623
+ clean_caption: bool = True,
624
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
625
+ ):
626
+ """
627
+ Function invoked when calling the pipeline for generation.
628
+
629
+ Args:
630
+ prompt (`str` or `List[str]`, *optional*):
631
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
632
+ instead.
633
+ num_inference_steps (`int`, *optional*, defaults to 100):
634
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
635
+ expense of slower inference.
636
+ timesteps (`List[int]`, *optional*):
637
+ Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
638
+ timesteps are used. Must be in descending order.
639
+ guidance_scale (`float`, *optional*, defaults to 7.0):
640
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
641
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
642
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
643
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
644
+ usually at the expense of lower image quality.
645
+ negative_prompt (`str` or `List[str]`, *optional*):
646
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
647
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
648
+ less than `1`).
649
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
650
+ The number of images to generate per prompt.
651
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
652
+ The height in pixels of the generated video frames.
653
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
654
+ The width in pixels of the generated video frames.
655
+ eta (`float`, *optional*, defaults to 0.0):
656
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
657
+ [`schedulers.DDIMScheduler`], will be ignored for others.
658
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
659
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
660
+ to make generation deterministic.
661
+ prompt_embeds (`torch.FloatTensor`, *optional*):
662
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
663
+ provided, text embeddings will be generated from `prompt` input argument.
664
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
665
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
666
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
667
+ argument.
668
+ output_type (`str`, *optional*, defaults to `"np"`):
669
+ The output format of the generated video. Choose between `"np"` (a list of NumPy frame arrays) and
670
+ `"pt"` (a `torch` tensor).
671
+ return_dict (`bool`, *optional*, defaults to `True`):
672
+ Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
673
+ callback (`Callable`, *optional*):
674
+ A function that will be called every `callback_steps` steps during inference. The function will be
675
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
676
+ callback_steps (`int`, *optional*, defaults to 1):
677
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
678
+ called at every step.
679
+ clean_caption (`bool`, *optional*, defaults to `True`):
680
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
681
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
682
+ prompt.
683
+ cross_attention_kwargs (`dict`, *optional*):
684
+ A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
685
+ `self.processor` in
686
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
687
+
688
+ Examples:
689
+
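+ A minimal usage sketch (the checkpoint id below is a placeholder for a repository that ships this
+ pipeline class; depending on the repository layout, `custom_pipeline=` may also be needed):
+
+ >>> import torch
+ >>> from diffusers import DiffusionPipeline
+ >>> pipe = DiffusionPipeline.from_pretrained("<repo-id>", torch_dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+ >>> frames = pipe("a panda surfing on a wave", num_frames=16, num_inference_steps=50).frames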
690
+ Returns:
691
+ [`TextToVideoPipelineOutput`] or `tuple`:
692
+ [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
693
+ the first (and only) element is the generated video frames.
696
+ """
697
+ # 1. Check inputs. Raise error if not correct
698
+ self.check_inputs(
699
+ prompt,
700
+ callback_steps,
701
+ negative_prompt,
702
+ prompt_embeds,
703
+ negative_prompt_embeds,
704
+ )
705
+
706
+ # 2. Define call parameters
707
+ height = height or self.unet.config.sample_size
708
+ width = width or self.unet.config.sample_size
709
+
710
+ if prompt is not None and isinstance(prompt, str):
711
+ batch_size = 1
712
+ elif prompt is not None and isinstance(prompt, list):
713
+ batch_size = len(prompt)
714
+ else:
715
+ batch_size = prompt_embeds.shape[0]
716
+
717
+ device = self._execution_device
718
+
719
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
720
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
721
+ # corresponds to doing no classifier free guidance.
722
+ do_classifier_free_guidance = guidance_scale > 1.0
723
+
724
+ # 3. Encode input prompt
725
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
726
+ prompt,
727
+ do_classifier_free_guidance,
728
+ num_images_per_prompt=num_images_per_prompt,
729
+ device=device,
730
+ negative_prompt=negative_prompt,
731
+ prompt_embeds=prompt_embeds,
732
+ negative_prompt_embeds=negative_prompt_embeds,
733
+ clean_caption=clean_caption,
734
+ )
735
+
736
+ if do_classifier_free_guidance:
737
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
738
+
739
+ # 4. Prepare timesteps
740
+ if timesteps is not None:
741
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
742
+ timesteps = self.scheduler.timesteps
743
+ num_inference_steps = len(timesteps)
744
+ else:
745
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
746
+ timesteps = self.scheduler.timesteps
747
+
748
+ # 5. Prepare intermediate images
749
+ intermediate_images = self.prepare_intermediate_images(
750
+ batch_size * num_images_per_prompt,
751
+ self.unet.config.in_channels,
752
+ num_frames,
753
+ height,
754
+ width,
755
+ prompt_embeds.dtype,
756
+ device,
757
+ generator,
758
+ )
759
+
760
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
761
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
762
+
763
+ # HACK: see comment in `enable_model_cpu_offload`
764
+ if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
765
+ self.text_encoder_offload_hook.offload()
766
+
767
+ # 7. Denoising loop
768
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
769
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
770
+ for i, t in enumerate(timesteps):
771
+ model_input = (
772
+ torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
773
+ )
774
+ model_input = self.scheduler.scale_model_input(model_input, t)
775
+
776
+ # predict the noise residual
777
+ noise_pred = self.unet(
778
+ model_input,
779
+ t,
780
+ encoder_hidden_states=prompt_embeds,
781
+ cross_attention_kwargs=cross_attention_kwargs,
782
+ ).sample
783
+
784
+ # perform guidance
785
+ if do_classifier_free_guidance:
786
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
787
+ noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
788
+ noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
789
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
790
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
791
+
792
+ if self.scheduler.config.variance_type not in [
793
+ "learned",
794
+ "learned_range",
795
+ ]:
796
+ noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
797
+
798
+ # reshape latents
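+ # (the scheduler step operates on image-shaped tensors, so (B, C, F, H, W) is folded into
+ # (B * F, C, H, W) here and unfolded again after the step below)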
799
+ bsz, channel, frames, height, width = intermediate_images.shape
800
+ intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(
801
+ bsz * frames, channel, height, width
802
+ )
803
+ noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
804
+
805
+ # compute the previous noisy sample x_t -> x_t-1
806
+ intermediate_images = self.scheduler.step(
807
+ noise_pred, t, intermediate_images, **extra_step_kwargs
808
+ ).prev_sample
809
+
810
+ # reshape latents back
811
+ intermediate_images = (
812
+ intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
813
+ )
814
+
815
+ # call the callback, if provided
816
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
817
+ progress_bar.update()
818
+ if callback is not None and i % callback_steps == 0:
819
+ callback(i, t, intermediate_images)
820
+
821
+ video_tensor = intermediate_images
822
+
823
+ if output_type == "pt":
824
+ video = video_tensor
825
+ else:
826
+ video = tensor2vid(video_tensor)
827
+
828
+ # Offload last model to CPU
829
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
830
+ self.final_offload_hook.offload()
831
+
832
+ if not return_dict:
833
+ return (video,)
834
+
835
+ return TextToVideoPipelineOutput(frames=video)