bluestarburst commited on
Commit
8dc8329
1 Parent(s): 09bf9a3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +428 -0
pipeline.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import numpy as np
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from diffusers.utils import is_accelerate_available
12
+ from packaging import version
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers.configuration_utils import FrozenDict
16
+ from diffusers.models import AutoencoderKL
17
+ from diffusers import DiffusionPipeline
18
+ from diffusers.schedulers import (
19
+ DDIMScheduler,
20
+ DPMSolverMultistepScheduler,
21
+ EulerAncestralDiscreteScheduler,
22
+ EulerDiscreteScheduler,
23
+ LMSDiscreteScheduler,
24
+ PNDMScheduler,
25
+ )
26
+ from diffusers.utils import deprecate, logging, BaseOutput
27
+
28
+ from einops import rearrange
29
+
30
+ from ..models.unet import UNet3DConditionModel
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ @dataclass
37
+ class AnimationPipelineOutput(BaseOutput):
38
+ videos: Union[torch.Tensor, np.ndarray]
39
+
40
+
41
+ class AnimationPipeline(DiffusionPipeline):
42
+ _optional_components = []
43
+
44
+ def __init__(
45
+ self,
46
+ vae: AutoencoderKL,
47
+ text_encoder: CLIPTextModel,
48
+ tokenizer: CLIPTokenizer,
49
+ unet: UNet3DConditionModel,
50
+ scheduler: Union[
51
+ DDIMScheduler,
52
+ PNDMScheduler,
53
+ LMSDiscreteScheduler,
54
+ EulerDiscreteScheduler,
55
+ EulerAncestralDiscreteScheduler,
56
+ DPMSolverMultistepScheduler,
57
+ ],
58
+ ):
59
+ super().__init__()
60
+
61
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
62
+ deprecation_message = (
63
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
64
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
65
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
66
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
67
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
68
+ " file"
69
+ )
70
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
71
+ new_config = dict(scheduler.config)
72
+ new_config["steps_offset"] = 1
73
+ scheduler._internal_dict = FrozenDict(new_config)
74
+
75
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
76
+ deprecation_message = (
77
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
78
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
79
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
80
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
81
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
82
+ )
83
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
84
+ new_config = dict(scheduler.config)
85
+ new_config["clip_sample"] = False
86
+ scheduler._internal_dict = FrozenDict(new_config)
87
+
88
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
89
+ version.parse(unet.config._diffusers_version).base_version
90
+ ) < version.parse("0.9.0.dev0")
91
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
92
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
93
+ deprecation_message = (
94
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
95
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
96
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
97
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
98
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
99
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
100
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
101
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
102
+ " the `unet/config.json` file"
103
+ )
104
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
105
+ new_config = dict(unet.config)
106
+ new_config["sample_size"] = 64
107
+ unet._internal_dict = FrozenDict(new_config)
108
+
109
+ self.register_modules(
110
+ vae=vae,
111
+ text_encoder=text_encoder,
112
+ tokenizer=tokenizer,
113
+ unet=unet,
114
+ scheduler=scheduler,
115
+ )
116
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
117
+
118
+ def enable_vae_slicing(self):
119
+ self.vae.enable_slicing()
120
+
121
+ def disable_vae_slicing(self):
122
+ self.vae.disable_slicing()
123
+
124
+ def enable_sequential_cpu_offload(self, gpu_id=0):
125
+ if is_accelerate_available():
126
+ from accelerate import cpu_offload
127
+ else:
128
+ raise ImportError("Please install accelerate via `pip install accelerate`")
129
+
130
+ device = torch.device(f"cuda:{gpu_id}")
131
+
132
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
133
+ if cpu_offloaded_model is not None:
134
+ cpu_offload(cpu_offloaded_model, device)
135
+
136
+
137
+ @property
138
+ def _execution_device(self):
139
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
140
+ return self.device
141
+ for module in self.unet.modules():
142
+ if (
143
+ hasattr(module, "_hf_hook")
144
+ and hasattr(module._hf_hook, "execution_device")
145
+ and module._hf_hook.execution_device is not None
146
+ ):
147
+ return torch.device(module._hf_hook.execution_device)
148
+ return self.device
149
+
150
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
151
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
152
+
153
+ text_inputs = self.tokenizer(
154
+ prompt,
155
+ padding="max_length",
156
+ max_length=self.tokenizer.model_max_length,
157
+ truncation=True,
158
+ return_tensors="pt",
159
+ )
160
+ text_input_ids = text_inputs.input_ids
161
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
162
+
163
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
164
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
165
+ logger.warning(
166
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
167
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
168
+ )
169
+
170
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
171
+ attention_mask = text_inputs.attention_mask.to(device)
172
+ else:
173
+ attention_mask = None
174
+
175
+ text_embeddings = self.text_encoder(
176
+ text_input_ids.to(device),
177
+ attention_mask=attention_mask,
178
+ )
179
+ text_embeddings = text_embeddings[0]
180
+
181
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
182
+ bs_embed, seq_len, _ = text_embeddings.shape
183
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
184
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
185
+
186
+ # get unconditional embeddings for classifier free guidance
187
+ if do_classifier_free_guidance:
188
+ uncond_tokens: List[str]
189
+ if negative_prompt is None:
190
+ uncond_tokens = [""] * batch_size
191
+ elif type(prompt) is not type(negative_prompt):
192
+ raise TypeError(
193
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
194
+ f" {type(prompt)}."
195
+ )
196
+ elif isinstance(negative_prompt, str):
197
+ uncond_tokens = [negative_prompt]
198
+ elif batch_size != len(negative_prompt):
199
+ raise ValueError(
200
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
201
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
202
+ " the batch size of `prompt`."
203
+ )
204
+ else:
205
+ uncond_tokens = negative_prompt
206
+
207
+ max_length = text_input_ids.shape[-1]
208
+ uncond_input = self.tokenizer(
209
+ uncond_tokens,
210
+ padding="max_length",
211
+ max_length=max_length,
212
+ truncation=True,
213
+ return_tensors="pt",
214
+ )
215
+
216
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
217
+ attention_mask = uncond_input.attention_mask.to(device)
218
+ else:
219
+ attention_mask = None
220
+
221
+ uncond_embeddings = self.text_encoder(
222
+ uncond_input.input_ids.to(device),
223
+ attention_mask=attention_mask,
224
+ )
225
+ uncond_embeddings = uncond_embeddings[0]
226
+
227
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
228
+ seq_len = uncond_embeddings.shape[1]
229
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
230
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
231
+
232
+ # For classifier free guidance, we need to do two forward passes.
233
+ # Here we concatenate the unconditional and text embeddings into a single batch
234
+ # to avoid doing two forward passes
235
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
236
+
237
+ return text_embeddings
238
+
239
+ def decode_latents(self, latents):
240
+ video_length = latents.shape[2]
241
+ latents = 1 / 0.18215 * latents
242
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
243
+ # video = self.vae.decode(latents).sample
244
+ video = []
245
+ for frame_idx in tqdm(range(latents.shape[0])):
246
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
247
+ video = torch.cat(video)
248
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
249
+ video = (video / 2 + 0.5).clamp(0, 1)
250
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
251
+ video = video.cpu().float().numpy()
252
+ return video
253
+
254
+ def prepare_extra_step_kwargs(self, generator, eta):
255
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
256
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
257
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
258
+ # and should be between [0, 1]
259
+
260
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
261
+ extra_step_kwargs = {}
262
+ if accepts_eta:
263
+ extra_step_kwargs["eta"] = eta
264
+
265
+ # check if the scheduler accepts generator
266
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
267
+ if accepts_generator:
268
+ extra_step_kwargs["generator"] = generator
269
+ return extra_step_kwargs
270
+
271
+ def check_inputs(self, prompt, height, width, callback_steps):
272
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
273
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
274
+
275
+ if height % 8 != 0 or width % 8 != 0:
276
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
277
+
278
+ if (callback_steps is None) or (
279
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
280
+ ):
281
+ raise ValueError(
282
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
283
+ f" {type(callback_steps)}."
284
+ )
285
+
286
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
287
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
288
+ if isinstance(generator, list) and len(generator) != batch_size:
289
+ raise ValueError(
290
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
291
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
292
+ )
293
+ if latents is None:
294
+ rand_device = "cpu" if device.type == "mps" else device
295
+
296
+ if isinstance(generator, list):
297
+ shape = shape
298
+ # shape = (1,) + shape[1:]
299
+ latents = [
300
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
301
+ for i in range(batch_size)
302
+ ]
303
+ latents = torch.cat(latents, dim=0).to(device)
304
+ else:
305
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
306
+ else:
307
+ if latents.shape != shape:
308
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
309
+ latents = latents.to(device)
310
+
311
+ # scale the initial noise by the standard deviation required by the scheduler
312
+ latents = latents * self.scheduler.init_noise_sigma
313
+ return latents
314
+
315
+ @torch.no_grad()
316
+ def __call__(
317
+ self,
318
+ prompt: Union[str, List[str]],
319
+ video_length: Optional[int],
320
+ height: Optional[int] = None,
321
+ width: Optional[int] = None,
322
+ num_inference_steps: int = 50,
323
+ guidance_scale: float = 7.5,
324
+ negative_prompt: Optional[Union[str, List[str]]] = None,
325
+ num_videos_per_prompt: Optional[int] = 1,
326
+ eta: float = 0.0,
327
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
328
+ latents: Optional[torch.FloatTensor] = None,
329
+ output_type: Optional[str] = "tensor",
330
+ return_dict: bool = True,
331
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
332
+ callback_steps: Optional[int] = 1,
333
+ **kwargs,
334
+ ):
335
+ # Default height and width to unet
336
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
337
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
338
+
339
+ # Check inputs. Raise error if not correct
340
+ self.check_inputs(prompt, height, width, callback_steps)
341
+
342
+ # Define call parameters
343
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
344
+ batch_size = 1
345
+ if latents is not None:
346
+ batch_size = latents.shape[0]
347
+ if isinstance(prompt, list):
348
+ batch_size = len(prompt)
349
+
350
+ device = self._execution_device
351
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
352
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
353
+ # corresponds to doing no classifier free guidance.
354
+ do_classifier_free_guidance = guidance_scale > 1.0
355
+
356
+ # Encode input prompt
357
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
358
+ if negative_prompt is not None:
359
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
360
+ text_embeddings = self._encode_prompt(
361
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
362
+ )
363
+
364
+ # Prepare timesteps
365
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
366
+ timesteps = self.scheduler.timesteps
367
+
368
+ # Prepare latent variables
369
+ num_channels_latents = self.unet.in_channels
370
+ latents = self.prepare_latents(
371
+ batch_size * num_videos_per_prompt,
372
+ num_channels_latents,
373
+ video_length,
374
+ height,
375
+ width,
376
+ text_embeddings.dtype,
377
+ device,
378
+ generator,
379
+ latents,
380
+ )
381
+ latents_dtype = latents.dtype
382
+
383
+ # Prepare extra step kwargs.
384
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
385
+
386
+ # Denoising loop
387
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
388
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
389
+ for i, t in enumerate(timesteps):
390
+ # expand the latents if we are doing classifier free guidance
391
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
392
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
393
+
394
+ # predict the noise residual
395
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
396
+ # noise_pred = []
397
+ # import pdb
398
+ # pdb.set_trace()
399
+ # for batch_idx in range(latent_model_input.shape[0]):
400
+ # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
401
+ # noise_pred.append(noise_pred_single)
402
+ # noise_pred = torch.cat(noise_pred)
403
+
404
+ # perform guidance
405
+ if do_classifier_free_guidance:
406
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
407
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
408
+
409
+ # compute the previous noisy sample x_t -> x_t-1
410
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
411
+
412
+ # call the callback, if provided
413
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
414
+ progress_bar.update()
415
+ if callback is not None and i % callback_steps == 0:
416
+ callback(i, t, latents)
417
+
418
+ # Post-processing
419
+ video = self.decode_latents(latents)
420
+
421
+ # Convert to tensor
422
+ if output_type == "tensor":
423
+ video = torch.from_numpy(video)
424
+
425
+ if not return_dict:
426
+ return video
427
+
428
+ return AnimationPipelineOutput(videos=video)