Spaces:
Runtime error
Runtime error
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| from dataclasses import dataclass | |
| from typing import Callable, List, Optional, Union | |
| import numpy as np | |
| from diffusers.utils import deprecate, logging, BaseOutput | |
| from einops import rearrange, repeat | |
| from torch.nn.functional import grid_sample | |
| import torchvision.transforms as T | |
| from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | |
| from diffusers.models import AutoencoderKL, UNet2DConditionModel | |
| from diffusers.schedulers import KarrasDiffusionSchedulers | |
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
| import PIL | |
| from PIL import Image | |
| from kornia.morphology import dilation | |
@dataclass
class TextToVideoPipelineOutput(BaseOutput):
    """Result container returned by ``TextToVideoPipeline.__call__``.

    ``BaseOutput`` subclasses are meant to be declared as dataclasses; the
    decorator turns the annotations below into real fields so the object can
    be built with keyword arguments and read back by attribute.
    """

    # Generated frames, either as a list of PIL images or as an array.
    images: Union[List[PIL.Image.Image], np.ndarray]
    # Per-image safety-checker flags; None when no check was performed.
    nsfw_content_detected: Optional[List[bool]]
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Adapted from
    https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py

    Args:
        batch: Number of copies of the grid to return.
        ht: Grid height (number of rows).
        wd: Grid width (number of columns).
        device: Device on which the grid is created.

    Returns:
        A float tensor of shape ``(batch, 2, ht, wd)``. Channel 0 holds the
        column (x) index of every cell and channel 1 the row (y) index.
    """
    # "ij" is the behaviour the original call relied on implicitly; passing it
    # explicitly avoids the warning newer PyTorch emits for a missing argument.
    rows, cols = torch.meshgrid(
        torch.arange(ht, device=device),
        torch.arange(wd, device=device),
        indexing="ij",
    )
    # Stack as (x, y), i.e. columns first, to match the flow-field layout.
    coords = torch.stack((cols, rows), dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)
class TextToVideoPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline extended to produce a sequence of frames.

    The first frame is denoised normally. When a motion field is used, an
    intermediate latent of that frame is replicated, warped by a constant
    per-frame translation, re-noised and denoised again together with the
    first frame. An optional background-smoothing pass blends the background
    of later frames with the (warped) background of the first frame.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        """Forward all components unchanged to ``StableDiffusionPipeline``."""
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)

    def DDPM_forward(self, x0, t0, tMax, generator, device, shape, text_embeddings):
        """Sample pure noise, or add noise to ``x0`` between two timesteps.

        Args:
            x0: Latents to noise, or ``None`` to draw fresh Gaussian noise.
            t0: Start index into ``self.scheduler.alphas``.
            tMax: End index (exclusive) into ``self.scheduler.alphas``.
            generator: Random generator passed to ``torch.randn``.
            device: Target ``torch.device``; noise is drawn on CPU for MPS.
            shape: Shape of the noise when ``x0`` is ``None``.
            text_embeddings: Only its dtype is used, for the sampled noise.

        Returns:
            Noise of ``shape`` when ``x0`` is ``None``; otherwise
            ``sqrt(a) * x0 + sqrt(1 - a) * eps`` with ``a`` the product of
            ``self.scheduler.alphas[t0:tMax]``.
        """
        rand_device = "cpu" if device.type == "mps" else device
        noise_dtype = text_embeddings.dtype
        if x0 is None:
            return torch.randn(shape, generator=generator, device=rand_device,
                               dtype=noise_dtype).to(device)

        # Move the noise to `device` as the branch above does; without this the
        # CPU-sampled noise (MPS case) cannot be combined with `x0`.
        eps = torch.randn(x0.shape, dtype=noise_dtype, generator=generator,
                          device=rand_device).to(device)
        alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
        return torch.sqrt(alpha_vec) * x0 + torch.sqrt(1 - alpha_vec) * eps

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        """Create (or adopt) initial latents and scale them for the scheduler.

        Args:
            batch_size: Number of samples.
            num_channels_latents: Latent channel count.
            video_length: Number of frames in the latent.
            height: Image height in pixels (divided by ``vae_scale_factor``).
            width: Image width in pixels (divided by ``vae_scale_factor``).
            dtype: dtype of freshly sampled latents.
            device: Target ``torch.device``; noise is drawn on CPU for MPS.
            generator: A generator, or a list with one generator per sample.
            latents: Optional pre-made latents; only moved to ``device``.

        Returns:
            Latents multiplied by ``self.scheduler.init_noise_sigma``.

        Raises:
            ValueError: If a generator list does not match ``batch_size``.
        """
        shape = (batch_size, num_channels_latents, video_length,
                 height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is not None:
            latents = latents.to(device)
        else:
            rand_device = "cpu" if device.type == "mps" else device
            if isinstance(generator, list):
                # One sample per generator, concatenated along the batch axis.
                single_shape = (1,) + shape[1:]
                samples = [
                    torch.randn(single_shape, generator=gen,
                                device=rand_device, dtype=dtype)
                    for gen in generator
                ]
                latents = torch.cat(samples, dim=0).to(device)
            else:
                latents = torch.randn(
                    shape, generator=generator, device=rand_device, dtype=dtype).to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return latents * self.scheduler.init_noise_sigma

    def warp_latents_independently(self, latents, reference_flow):
        """Warp every frame of ``latents`` with its own flow field.

        Args:
            latents: 5-D tensor ``(1, c, f, h, w)``; the batch size must be 1.
            reference_flow: 4-D flow whose last two dimensions are ``(H, W)``.
                It is added to an ``(f, 2, H, W)`` coordinate grid, so it
                must broadcast against that shape.

        Returns:
            Warped latents with layout ``(1, c, f, h, w)``.
        """
        _, _, H, W = reference_flow.size()
        b, _, f, h, w = latents.size()
        assert b == 1
        coords0 = coords_grid(f, H, W, device=latents.device).to(latents.dtype)

        # Displace the pixel grid, then normalise to the [-1, 1] range that
        # grid_sample expects.
        coords_t0 = coords0 + reference_flow
        coords_t0[:, 0] /= W
        coords_t0[:, 1] /= H
        coords_t0 = coords_t0 * 2.0 - 1.0

        # Bring the sampling grid down to the latent resolution.
        coords_t0 = T.Resize((h, w))(coords_t0)
        coords_t0 = rearrange(coords_t0, 'f c h w -> f h w c')
        latents_0 = rearrange(latents[0], 'c f h w -> f c h w')

        # align_corners=False is grid_sample's default; stating it explicitly
        # keeps results identical while avoiding the "default changed" warning.
        warped = grid_sample(latents_0, coords_t0, mode='nearest',
                             padding_mode='reflection', align_corners=False)
        return rearrange(warped, '(b f) c h w -> b c f h w', f=f)

    def DDIM_backward(self, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance, null_embs, text_embeddings, latents_local,
                      latents_dtype, guidance_scale, guidance_stop_step, callback, callback_steps, extra_step_kwargs, num_warmup_steps):
        """Run the denoising loop over every timestep not above ``skip_t``.

        Args:
            num_inference_steps: Total used for the progress bar.
            timesteps: Scheduler timesteps to iterate over.
            skip_t: Timesteps strictly greater than this value are skipped.
            t0: When the *next* timestep equals ``t0`` the current latents are
                stored under ``"x_t0_1"``.
            t1: Same as ``t0`` but stored under ``"x_t1_1"``.
            do_classifier_free_guidance: Whether to run guided denoising.
            null_embs: Optional per-step replacement for the unconditional
                embedding, read as ``null_embs[i][0]``.
            text_embeddings: Prompt embeddings; with guidance, row 0 is used
                as the unconditional and row 1 as the conditional embedding.
            latents_local: Starting latents, 5-D with frames on axis 2.
            latents_dtype: dtype the predicted noise is cast to.
            guidance_scale: Classifier-free guidance weight.
            guidance_stop_step: Accepted for interface compatibility; unused.
            callback: Optional ``callback(i, t, latents)``.
            callback_steps: Callback frequency in steps.
            extra_step_kwargs: Extra kwargs for ``self.scheduler.step``.
            num_warmup_steps: Warm-up steps for progress accounting.

        Returns:
            Dict with ``"x0"`` (final latents) and, if the matching timestep
            was encountered, ``"x_t0_1"`` and/or ``"x_t1_1"``.
        """
        entered = False
        f = latents_local.shape[2]

        # Fold frames into the batch axis so the 2-D UNet sees one image each.
        latents = rearrange(
            latents_local, "b c f w h -> (b f) c w h").detach().clone()

        x_t0_1 = None
        x_t1_1 = None
        last_index = len(timesteps) - 1

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if t > skip_t:
                    continue
                if not entered:
                    print(
                        f"Continue DDIM with i = {i}, t = {t}, latent = {latents.shape}, device = {latents.device}, type = {latents.dtype}")
                    entered = True

                latents = latents.detach()

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # predict the noise residual
                with torch.no_grad():
                    if do_classifier_free_guidance:
                        uncond_emb = text_embeddings[0]
                        if null_embs is not None:
                            # Use the per-step null embedding without writing
                            # into the caller's tensor.
                            uncond_emb = null_embs[i][0].to(
                                device=text_embeddings.device, dtype=text_embeddings.dtype)
                        te = torch.cat([repeat(uncond_emb, "c k -> f c k", f=f),
                                        repeat(text_embeddings[1], "c k -> f c k", f=f)])
                    else:
                        # Without guidance there is only the prompt embedding.
                        te = repeat(text_embeddings[0], "c k -> f c k", f=f)
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=te).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Snapshot the latents when they have just reached t0 / t1.
                if i < last_index and timesteps[i + 1] == t0:
                    x_t0_1 = latents.detach().clone()
                    print(f"latent t0 found at i = {i}, t = {t}")
                elif i < last_index and timesteps[i + 1] == t1:
                    x_t1_1 = latents.detach().clone()
                    print(f"latent t1 found at i={i}, t = {t}")

                # call the callback, if provided
                if i == last_index or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Unfold the frame axis again before returning.
        latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)
        res = {"x0": latents.detach().clone()}
        if x_t0_1 is not None:
            x_t0_1 = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t0_1"] = x_t0_1.detach().clone()
        if x_t1_1 is not None:
            x_t1_1 = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t1_1"] = x_t1_1.detach().clone()
        return res

    def decode_latents(self, latents):
        """Decode 5-D latents with the VAE into frames in the range [0, 1].

        Args:
            latents: Tensor laid out as ``(b, c, f, h, w)``.

        Returns:
            A detached CPU tensor laid out as ``(b, c, f, h, w)``, clamped to
            ``[0, 1]``. The dtype is whatever the VAE produced; no cast to
            float32 is performed here.
        """
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        # The VAE is 2-D, so decode every frame as an independent image.
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        return video.detach().cpu()

    def create_motion_field(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
        """Build a constant translation flow for every frame after the first.

        Args:
            motion_field_strength_x: Horizontal displacement per frame id.
            motion_field_strength_y: Vertical displacement per frame id.
            frame_ids: Frame ids; entry ``i`` scales the flow of slot ``i``.
            video_length: Total frame count; the flow has one slot fewer.
            latents: Only its device and dtype are used.

        Returns:
            Tensor of shape ``(video_length - 1, 2, 512, 512)`` where slot
            ``i`` is filled with ``strength * frame_ids[i]``. The 512x512
            resolution is hard-coded.
        """
        reference_flow = torch.zeros(
            (video_length - 1, 2, 512, 512), device=latents.device, dtype=latents.dtype)
        for fr_idx, frame_id in enumerate(frame_ids):
            reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * frame_id
            reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * frame_id
        return reference_flow

    def create_motion_field_and_warp_latents(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
        """Create the motion field and warp ``latents`` with it, in place.

        Args:
            motion_field_strength_x: See ``create_motion_field``.
            motion_field_strength_y: See ``create_motion_field``.
            frame_ids: See ``create_motion_field``.
            video_length: See ``create_motion_field``.
            latents: 5-D latents; every batch entry is overwritten with its
                warped version.

        Returns:
            Tuple ``(motion_field, latents)``.
        """
        motion_field = self.create_motion_field(
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            latents=latents, video_length=video_length, frame_ids=frame_ids)
        for idx, latent in enumerate(latents):
            latents[idx] = self.warp_latents_independently(
                latent[None], motion_field)
        return motion_field, latents

    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        guidance_stop_step: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        xT: Optional[torch.FloatTensor] = None,
        null_embs: Optional[torch.FloatTensor] = None,
        motion_field_strength_x: float = 12,
        motion_field_strength_y: float = 12,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        use_motion_field: bool = True,
        smooth_bg: bool = False,
        smooth_bg_strength: float = 0.4,
        t0: int = 44,
        t1: int = 47,
        **kwargs,
    ):
        """Generate ``video_length`` frames for a prompt.

        Args:
            prompt: Non-empty list of identical prompt strings.
            video_length: Number of frames to generate.
            height: Frame height; defaults to the UNet sample size times the
                VAE scale factor.
            width: Frame width; same default rule as ``height``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is
                enabled when it is greater than 1.
            guidance_stop_step: Forwarded to ``DDIM_backward``.
            negative_prompt: ``None`` or a list of identical strings.
            num_videos_per_prompt: Must be 1.
            eta: Forwarded to ``prepare_extra_step_kwargs``.
            generator: Random generator(s) for noise sampling.
            xT: Optional initial latents.
            null_embs: Optional per-step unconditional embeddings.
            motion_field_strength_x: Horizontal motion per frame id.
            motion_field_strength_y: Vertical motion per frame id.
            output_type: ``"latent"`` returns latents; anything else decodes
                them and runs the safety checker.
            return_dict: Return a ``TextToVideoPipelineOutput`` instead of a
                tuple.
            callback: Optional per-step callback.
            callback_steps: Callback frequency.
            use_motion_field: Derive later frames by warping the first frame's
                latent instead of sampling independent noise per frame.
            smooth_bg: Run the background-smoothing pass. Requires a
                ``sod_model`` attribute providing ``process_data``.
            smooth_bg_strength: Blend weight of the current latents inside the
                background region.
            t0: Index into the ascending timestep table ``1, 21, ..., 981``.
            t1: Index into the same table; must be greater than ``t0``.
            **kwargs: ``frame_ids`` overrides the default
                ``list(range(video_length))``.

        Returns:
            ``TextToVideoPipelineOutput`` or, when ``return_dict`` is false,
            the tuple ``(images, nsfw_content_detected)``.

        Raises:
            AttributeError: If ``smooth_bg`` is set but no ``sod_model`` is
                attached to the pipeline.
            RuntimeError: If the intermediate latents at ``t0``/``t1`` were
                not produced by the first denoising pass.
        """
        frame_ids = kwargs.pop("frame_ids", list(range(video_length)))
        assert t0 < t1
        assert num_videos_per_prompt == 1
        assert isinstance(prompt, list) and len(prompt) > 0
        assert isinstance(negative_prompt, list) or negative_prompt is None

        # Fail before the expensive denoising passes rather than after them.
        if smooth_bg and getattr(self, "sod_model", None) is None:
            raise AttributeError(
                "smooth_bg=True requires a `sod_model` attribute on the pipeline, but none is set.")

        # Every entry of a prompt list must be identical; collapse each list to
        # its single string. A missing/empty negative prompt stays None.
        collapsed = []
        for candidates in (prompt, negative_prompt):
            if not candidates:
                collapsed.append(None)
                continue
            template = candidates[0]
            for candidate in candidates:
                assert candidate == template
            collapsed.append(template)
        prompt, negative_prompt = collapsed

        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        effective_batch = batch_size * num_videos_per_prompt
        xT = self.prepare_latents(
            effective_batch, num_channels_latents, 1, height, width,
            text_embeddings.dtype, device, generator, xT,
        )
        dtype = xT.dtype

        # when motion field is not used, augment with random latent codes
        if use_motion_field:
            xT = xT[:, :, :1]
        elif xT.shape[2] < video_length:
            xT_missing = self.prepare_latents(
                effective_batch, num_channels_latents, video_length - xT.shape[2],
                height, width, text_embeddings.dtype, device, generator, None,
            )
            xT = torch.cat([xT, xT_missing], dim=2)

        # Ascending timestep table 1, 21, ..., 981 (50 entries); t0 and t1 are
        # indices into it.
        timesteps_ddpm = list(range(1, 1000, 20))
        t0 = timesteps_ddpm[t0]
        t1 = timesteps_ddpm[t1]
        print(f"t0 = {t0}  t1 = {t1}")

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        shape = (batch_size, num_channels_latents, 1,
                 height // self.vae_scale_factor, width // self.vae_scale_factor)

        # Arguments shared by every denoising pass below.
        ddim_kwargs = dict(
            num_inference_steps=num_inference_steps, timesteps=timesteps,
            do_classifier_free_guidance=do_classifier_free_guidance,
            null_embs=null_embs, text_embeddings=text_embeddings, latents_dtype=dtype,
            guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step,
            callback=callback, callback_steps=callback_steps,
            extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps,
        )

        # First pass: full denoising, remembering the latents at t0 and t1.
        ddim_res = self.DDIM_backward(
            skip_t=1000, t0=t0, t1=t1, latents_local=xT, **ddim_kwargs)
        x0 = ddim_res["x0"].detach()
        x_t0_1 = ddim_res.get("x_t0_1")
        x_t1_1 = ddim_res.get("x_t1_1")
        del ddim_res
        del xT

        missing_latents_msg = (
            f"Intermediate latents for t0 = {t0} / t1 = {t1} were not produced; "
            "these values must occur in the scheduler timesteps for the chosen num_inference_steps."
        )

        reference_flow = None
        x_t1 = None
        if use_motion_field:
            if x_t0_1 is None or x_t1_1 is None:
                raise RuntimeError(missing_latents_msg)
            del x0

            # Replicate the first frame's t0 latent and move it along the flow.
            x_t0_k = x_t0_1[:, :, :1, :, :].repeat(
                1, 1, video_length - 1, 1, 1)
            reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(
                motion_field_strength_x=motion_field_strength_x,
                motion_field_strength_y=motion_field_strength_y,
                latents=x_t0_k, video_length=video_length, frame_ids=frame_ids[1:])

            # Re-noise the warped latents from t0 up to t1.
            if t1 > t0:
                x_t1_k = self.DDPM_forward(
                    x0=x_t0_k, t0=t0, tMax=t1, device=device, shape=shape,
                    text_embeddings=text_embeddings, generator=generator)
            else:
                x_t1_k = x_t0_k

            x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2).clone().detach()

            # Second pass: denoise all frames jointly starting at t1.
            ddim_res = self.DDIM_backward(
                skip_t=t1, t0=-1, t1=-1, latents_local=x_t1, **ddim_kwargs)
            x0 = ddim_res["x0"].detach()
            del ddim_res
            del x_t1_k
            if not smooth_bg:
                # The smoothing pass still needs these; free them otherwise.
                del x_t1
                del x_t1_1
        elif smooth_bg:
            if x_t1_1 is None:
                raise RuntimeError(missing_latents_msg)
            # All frames were denoised in the first pass: keep the full t1
            # latents and the first frame separately for smoothing.
            x_t1 = x_t1_1.clone()
            x_t1_1 = x_t1_1[:, :, :1, :, :].clone()

        # smooth background
        if smooth_bg:
            h, w = x0.shape[3], x0.shape[4]
            M_FG = torch.zeros((batch_size, video_length, h, w),
                               device=x0.device, dtype=x0.dtype)
            resize_mask = T.Resize(
                size=(h, w), interpolation=T.InterpolationMode.NEAREST)
            kernel = torch.ones(5, 5, device=x0.device, dtype=x0.dtype)
            for batch_idx, x0_b in enumerate(x0):
                z0_b = self.decode_latents(x0_b[None]).detach()
                z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
                for frame_idx, z0_f in enumerate(z0_b):
                    z0_f = torch.round(
                        z0_f * 255).cpu().numpy().astype(np.uint8)
                    # apply SOD detection
                    m_f = torch.tensor(self.sod_model.process_data(
                        z0_f), device=x0.device).to(x0.dtype)
                    mask = resize_mask(m_f[None])
                    mask = dilation(mask[None].to(x0.device), kernel)[0]
                    M_FG[batch_idx, frame_idx, :, :] = mask

            # First-frame latents with the foreground removed.
            x_t1_1_fg_masked = x_t1_1 * \
                (1 - repeat(M_FG[:, 0, :, :],
                            "b w h -> b c 1 w h", c=x_t1_1.shape[1]))

            x_t1_1_fg_masked_moved = []
            for x_t1_1_fg_masked_b in x_t1_1_fg_masked:
                x_t1_fg_masked_b = x_t1_1_fg_masked_b.clone().repeat(
                    1, video_length - 1, 1, 1)[None]
                if use_motion_field:
                    x_t1_fg_masked_b = self.warp_latents_independently(
                        x_t1_fg_masked_b, reference_flow)
                x_t1_fg_masked_b = torch.cat(
                    [x_t1_1_fg_masked_b[None], x_t1_fg_masked_b], dim=2)
                x_t1_1_fg_masked_moved.append(x_t1_fg_masked_b)
            x_t1_1_fg_masked_moved = torch.cat(x_t1_1_fg_masked_moved, dim=0)

            # First-frame foreground mask carried to every later frame.
            M_FG_warped = []
            for m_fg_1_b in M_FG[:, :1, :, :]:
                m_fg_1_b = m_fg_1_b[None, None]
                m_fg_b = m_fg_1_b.repeat(1, 1, video_length - 1, 1, 1)
                if use_motion_field:
                    m_fg_b = self.warp_latents_independently(
                        m_fg_b.clone(), reference_flow)
                M_FG_warped.append(
                    torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
            M_FG_warped = torch.cat(M_FG_warped, dim=0)

            # Background = neither the frame's own foreground nor the moved
            # first-frame foreground.
            channels = x0.shape[1]
            M_BG = (1 - M_FG) * (1 - M_FG_warped)
            M_BG = repeat(M_BG, "b f h w -> b c f h w", c=channels)
            a_convex = smooth_bg_strength

            latents = (1 - M_BG) * x_t1 + M_BG * (
                a_convex * x_t1 + (1 - a_convex) * x_t1_1_fg_masked_moved)

            ddim_res = self.DDIM_backward(
                skip_t=t1, t0=-1, t1=-1, latents_local=latents, **ddim_kwargs)
            x0 = ddim_res["x0"].detach()
            del ddim_res
            del latents

        latents = x0

        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
        torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        else:
            image = self.decode_latents(latents)
            # Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, text_embeddings.dtype)
        image = rearrange(image, "b c f h w -> (b f) h w c")

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)
        return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)