"""Text-to-video pipeline built on top of ``StableDiffusionPipeline``.

A single latent frame is denoised with DDIM, optionally warped with a
constant translation ("motion field") to obtain the remaining frames, noised
forward again and denoised jointly.  An optional background-smoothing stage
blends the warped first frame into the background of every frame.
"""
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy as np
import PIL
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate, logging, BaseOutput
from einops import rearrange, repeat
from kornia.morphology import dilation
from PIL import Image
from torch.nn.functional import grid_sample
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


@dataclass
class TextToVideoPipelineOutput(BaseOutput):
    """Result container returned by ``TextToVideoPipeline.__call__``."""

    # Decoded frames (or raw latents when ``output_type == "latent"``).
    images: Union[List[PIL.Image.Image], np.ndarray]
    # Whatever ``run_safety_checker`` reported; ``None`` for latent output.
    nsfw_content_detected: Optional[List[bool]]


def coords_grid(batch, ht, wd, device):
    """Return a ``(batch, 2, ht, wd)`` float grid of pixel coordinates.

    Channel 0 holds the column (x) index and channel 1 the row (y) index of
    every pixel; the same grid is repeated ``batch`` times.

    Adapted from
    https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
    """
    rows_cols = torch.meshgrid(
        torch.arange(ht, device=device), torch.arange(wd, device=device))
    # meshgrid yields (row, col); reverse so that x comes first.
    coords = torch.stack(rows_cols[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


class TextToVideoPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline extended to produce a short frame sequence."""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        """Forward all components unchanged to ``StableDiffusionPipeline``."""
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)

    def DDPM_forward(self, x0, t0, tMax, generator, device, shape, text_embeddings):
        """Noise ``x0`` forward from timestep ``t0`` to ``tMax``.

        Args:
            x0: Latents to noise, or ``None`` to draw pure Gaussian noise of
                ``shape`` instead.
            t0, tMax: Slice bounds into ``self.scheduler.alphas``; the product
                of ``alphas[t0:tMax]`` is the signal-retention factor.
            generator: Optional ``torch.Generator`` (or list of them; only the
                first is used) that makes the sampled noise reproducible.
            device: Target device of the result.
            shape: Shape of the noise when ``x0`` is ``None``.
            text_embeddings: Only its dtype is read.

        Returns:
            ``sqrt(a) * x0 + sqrt(1 - a) * eps`` with ``a = prod(alphas[t0:tMax])``,
            or plain noise when ``x0`` is ``None``.
        """
        noise_dtype = text_embeddings.dtype
        rand_device = "cpu" if device.type == "mps" else device
        gen = generator[0] if isinstance(generator, list) else generator

        if x0 is None:
            return torch.randn(shape, generator=gen, device=rand_device,
                               dtype=noise_dtype).to(device)

        if gen is None:
            eps = torch.randn_like(x0, dtype=noise_dtype).to(device)
        else:
            # randn_like cannot take a generator, which made seeded runs
            # non-reproducible.  Sample on the generator's own device (a
            # generator only works on the device it was created for) and in
            # float32, then move/cast to where the latents live.
            eps = torch.randn(x0.shape, generator=gen, device=gen.device,
                              dtype=torch.float32).to(device=device, dtype=noise_dtype)

        alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
        return torch.sqrt(alpha_vec) * x0 + torch.sqrt(1 - alpha_vec) * eps

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        """Create (or adopt) initial latents of shape ``(b, c, f, h/s, w/s)``.

        ``s`` is ``self.vae_scale_factor``.  When ``latents`` is given it is
        only moved to ``device``.  The result is always multiplied by
        ``self.scheduler.init_noise_sigma``.

        Raises:
            ValueError: If ``generator`` is a list whose length differs from
                ``batch_size``.
        """
        shape = (batch_size, num_channels_latents, video_length,
                 height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is not None:
            latents = latents.to(device)
        else:
            rand_device = "cpu" if device.type == "mps" else device
            if isinstance(generator, list):
                # One sample per generator, concatenated along the batch axis.
                single_shape = (1,) + shape[1:]
                samples = [
                    torch.randn(single_shape, generator=gen,
                                device=rand_device, dtype=dtype)
                    for gen in generator
                ]
                latents = torch.cat(samples, dim=0).to(device)
            else:
                latents = torch.randn(
                    shape, generator=generator, device=rand_device, dtype=dtype).to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        return latents * self.scheduler.init_noise_sigma

    def warp_latents_independently(self, latents, reference_flow, inject_noise=False):
        """Warp every frame of ``latents`` with its own flow field.

        Args:
            latents: ``(1, c, f, h, w)`` tensor (batch size must be 1).
            reference_flow: ``(f, 2, H, W)`` displacement field; channel 0 is
                added to x coordinates, channel 1 to y coordinates.
            inject_noise: When true, pixels whose sampling location falls
                outside ``[-1, 1)`` are replaced by fresh Gaussian noise.

        Returns:
            Warped latents of shape ``(1, c, f, h, w)``.
        """
        _, _, H, W = reference_flow.size()
        b, _, f, h, w = latents.size()
        assert b == 1

        coords0 = coords_grid(f, H, W, device=latents.device).to(latents.dtype)
        coords_t0 = coords0 + reference_flow
        # Normalise to the [-1, 1] range expected by grid_sample.
        coords_t0[:, 0] /= W
        coords_t0[:, 1] /= H
        coords_t0 = coords_t0 * 2.0 - 1.0
        # Bring the sampling grid down to the latent resolution.
        coords_t0 = T.Resize((h, w))(coords_t0)
        coords_t0 = rearrange(coords_t0, 'f c h w -> f h w c')

        latents_0 = rearrange(latents[0], 'c f h w -> f c h w')
        warped = grid_sample(latents_0, coords_t0,
                             mode='nearest', padding_mode='reflection')

        if inject_noise:
            out_of_range = torch.logical_or(coords_t0 >= 1, coords_t0 < -1)
            # A pixel is reset when either of its two coordinates is outside.
            idx = torch.logical_or(out_of_range[:, :, :, 0], out_of_range[:, :, :, 1])
            idx = repeat(idx, "f w h -> f c w h", c=warped.shape[1])
            reset_noise = torch.randn(
                size=warped.shape, dtype=warped.dtype, device=warped.device)
            warped[idx] = reset_noise[idx]

        return rearrange(warped, '(b f) c h w -> b c f h w', f=f)

    def DDIM_backward(self, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance, null_embs, text_embeddings, latents_local,
                      latents_dtype, guidance_scale, guidance_stop_step, callback, callback_steps, extra_step_kwargs, num_warmup_steps):
        """Run the scheduler's reverse process on ``latents_local``.

        Timesteps greater than ``skip_t`` are skipped.  While denoising, the
        latents obtained right before the scheduler reaches timestep ``t0``
        (resp. ``t1``) are recorded.

        Args:
            latents_local: ``(b, c, f, h, w)`` latents; frames are folded into
                the batch axis for the UNet.  The text embeddings are repeated
                ``f`` times only, so this presumably expects ``b == 1``.
            null_embs: Optional per-step unconditional embeddings.  When given
                (and guidance is on) ``text_embeddings[0]`` is overwritten in
                place at every step.
            guidance_stop_step: Accepted for interface compatibility; it
                currently has no effect on the result.

        Returns:
            Dict with key ``"x0"`` (final latents) and, when the respective
            timestep was encountered, ``"x_t0_1"`` / ``"x_t1_1"``; all shaped
            ``(b, c, f, h, w)``.
        """
        f = latents_local.shape[2]
        latents = rearrange(
            latents_local, "b c f w h -> (b f) c w h").detach().clone()

        x_t0_1 = None
        x_t1_1 = None
        entered = False
        last_index = len(timesteps) - 1

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if t > skip_t:
                    continue
                if not entered:
                    print(
                        f"Continue DDIM with i = {i}, t = {t}, latent = {latents.shape}, device = {latents.device}, type = {latents.dtype}")
                    entered = True

                latents = latents.detach()

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # predict the noise residual
                with torch.no_grad():
                    if do_classifier_free_guidance:
                        if null_embs is not None:
                            text_embeddings[0] = null_embs[i][0]
                        te = torch.cat([
                            repeat(text_embeddings[0, :, :], "c k -> f c k", f=f),
                            repeat(text_embeddings[1, :, :], "c k -> f c k", f=f),
                        ])
                    else:
                        # Without guidance there is a single (conditional)
                        # embedding and the model input is not doubled, so
                        # indexing [1] would raise IndexError.
                        te = repeat(text_embeddings[0, :, :], "c k -> f c k", f=f)
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=te).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Record the latents that correspond to timesteps t0 / t1.
                if i < last_index and timesteps[i + 1] == t0:
                    x_t0_1 = latents.detach().clone()
                    print(f"latent t0 found at i = {i}, t = {t}")
                elif i < last_index and timesteps[i + 1] == t1:
                    x_t1_1 = latents.detach().clone()
                    print(f"latent t1 found at i={i}, t = {t}")

                # call the callback, if provided
                if i == last_index or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)
        res = {"x0": latents.detach().clone()}
        if x_t0_1 is not None:
            x_t0_1 = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t0_1"] = x_t0_1.detach().clone()
        if x_t1_1 is not None:
            x_t1_1 = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t1_1"] = x_t1_1.detach().clone()
        return res

    def decode_latents(self, latents):
        """Decode ``(b, c, f, h, w)`` latents into a video tensor in ``[0, 1]``.

        Frames are folded into the batch axis for the VAE and unfolded again.
        The result is detached and moved to the CPU; its dtype is whatever the
        VAE produced (no cast is performed here).
        """
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        return video.detach().cpu()

    def create_motion_field(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
        """Build a constant translation field for frames ``1 .. video_length-1``.

        The field has the fixed shape ``(video_length - 1, 2, 512, 512)``;
        entry ``k`` is filled with ``strength * (frame_ids[k] + 1)`` (x in
        channel 0, y in channel 1).  ``latents`` only supplies device and
        dtype.
        """
        reference_flow = torch.zeros(
            (video_length - 1, 2, 512, 512), device=latents.device, dtype=latents.dtype)
        for fr_idx in range(video_length - 1):
            step = frame_ids[fr_idx] + 1
            reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * step
            reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * step
        return reference_flow

    def create_motion_field_and_warp_latents(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, inject_noise_to_warp, latents):
        """Create the motion field and warp ``latents`` with it, in place.

        Returns:
            ``(motion_field, latents)`` where ``latents`` is the same tensor
            object that was passed in, now holding the warped values.
        """
        motion_field = self.create_motion_field(
            motion_field_strength_x=motion_field_strength_x,
            motion_field_strength_y=motion_field_strength_y,
            latents=latents, video_length=video_length, frame_ids=frame_ids)
        for idx, latent in enumerate(latents):
            latents[idx] = self.warp_latents_independently(
                latent[None], motion_field, inject_noise=inject_noise_to_warp)
        return motion_field, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        guidance_stop_step: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        xT: Optional[torch.FloatTensor] = None,
        null_embs: Optional[torch.FloatTensor] = None,
        motion_field_strength_x: float = 12,
        motion_field_strength_y: float = 12,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[
            int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        use_motion_field: bool = True,
        smooth_bg: bool = False,
        smooth_bg_strength: float = 0.4,
        inject_noise_to_warp: bool = False,
        t0: int = 44,
        t1: int = 47,
        **kwargs,
    ):
        """Generate ``video_length`` frames for a prompt.

        Args:
            prompt: Non-empty list of identical prompt strings (one per
                frame); it is collapsed to a single string.
            negative_prompt: ``None`` or a list of identical strings, handled
                like ``prompt``.
            video_length: Number of frames to produce.
            xT: Optional initial latents; sampled when ``None``.
            null_embs: Optional per-step unconditional embeddings forwarded to
                ``DDIM_backward``.
            use_motion_field: Derive frames 2..n by warping the first frame's
                latents instead of sampling independent noise per frame.
            smooth_bg: Run the background-smoothing stage.  It reads
                ``self.sod_model``, which this class does not define --
                presumably attached by the caller; verify before enabling.
            inject_noise_to_warp: Selects an alternative branch in which no
                warp is applied (the warp call is commented out upstream).
            t0, t1: Indices into a hard-coded ascending timestep table
                ``1, 21, ..., 981``; the looked-up values must occur in the
                scheduler's timesteps for the intermediate latents to be
                captured.
            **kwargs: ``frame_ids`` (defaults to ``range(video_length)``) is
                consumed; anything else is ignored.

        Returns:
            ``TextToVideoPipelineOutput`` or, when ``return_dict`` is false,
            the tuple ``(images, nsfw_content_detected)``.

        Raises:
            RuntimeError: If a required intermediate latent was never
                captured, or ``smooth_bg`` is requested on a path that has no
                motion field.
        """
        frame_ids = kwargs.pop("frame_ids", list(range(video_length)))
        assert num_videos_per_prompt == 1
        assert isinstance(prompt, list) and len(prompt) > 0
        assert isinstance(negative_prompt, list) or negative_prompt is None

        # Collapse each per-frame prompt list to its single shared string.
        # A missing negative prompt stays None instead of being iterated,
        # which used to raise TypeError for the default argument.
        collapsed = []
        for prompt_list in (prompt, negative_prompt):
            if not prompt_list:
                collapsed.append(None)
                continue
            template = prompt_list[0]
            assert all(entry == template for entry in prompt_list)
            collapsed.append(template)
        prompt, negative_prompt = collapsed

        print(
            f" Motion field strength x = {motion_field_strength_x}, y = {motion_field_strength_y}")
        print(f" Use: Motion field = {use_motion_field}")
        print(f" Use: Background smoothing = {smooth_bg}")
        print(f"Inject noise to warp = {inject_noise_to_warp}")

        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.in_channels
        xT = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            1,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            xT,
        )
        dtype = xT.dtype

        # when motion field is not used, augment with random latent codes
        if use_motion_field:
            xT = xT[:, :, :1]
        elif xT.shape[2] < video_length:
            xT_missing = self.prepare_latents(
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                video_length - xT.shape[2],
                height,
                width,
                text_embeddings.dtype,
                device,
                generator,
                None,
            )
            xT = torch.cat([xT, xT_missing], dim=2)

        # Ascending table 1, 21, ..., 981 (50 entries); t0/t1 index into it.
        timesteps_ddpm = list(range(1, 1000, 20))
        t0 = timesteps_ddpm[t0]
        t1 = timesteps_ddpm[t1]
        print(f"t0 = {t0}  t1 = {t1}")

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        shape = (batch_size, num_channels_latents, 1,
                 height // self.vae_scale_factor, width // self.vae_scale_factor)

        def run_ddim(latents_local, skip_t, t0, t1):
            # Every DDIM pass shares all settings except these four.
            return self.DDIM_backward(
                num_inference_steps=num_inference_steps, timesteps=timesteps,
                skip_t=skip_t, t0=t0, t1=t1,
                do_classifier_free_guidance=do_classifier_free_guidance,
                null_embs=null_embs, text_embeddings=text_embeddings,
                latents_local=latents_local, latents_dtype=dtype,
                guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step,
                callback=callback, callback_steps=callback_steps,
                extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)

        def missing_checkpoint(name, timestep):
            return RuntimeError(
                f"DDIM pass never produced {name}: timestep {timestep} is not among the "
                "scheduler timesteps (check num_inference_steps and the scheduler's "
                "timestep spacing against the hard-coded t0/t1 table)")

        noise_injection_path = inject_noise_to_warp and use_motion_field
        reference_flow = None

        if noise_injection_path:
            # NOTE(review): the warp that used to follow here is commented out
            # upstream, so these are plain copies of the first frame.
            x_t0_k = xT[:, :, :1, :, :].repeat(1, 1, video_length - 1, 1, 1)

        # First pass: denoise from the start, capturing the t0 / t1 latents.
        ddim_res = run_ddim(xT, skip_t=1000, t0=t0, t1=t1)
        x0 = ddim_res["x0"].detach()
        x_t0_1 = ddim_res["x_t0_1"].detach() if "x_t0_1" in ddim_res else None
        x_t1_1 = ddim_res["x_t1_1"].detach() if "x_t1_1" in ddim_res else None
        del ddim_res
        del xT

        if noise_injection_path:
            # DDPM forward to allow for more motion
            if t1 > t0:
                if x_t0_1 is None:
                    raise missing_checkpoint("x_t0_1", t0)
                x_t1_k = self.DDPM_forward(
                    x0=x_t0_1, t0=t0, tMax=t1, device=device, shape=shape,
                    text_embeddings=text_embeddings, generator=generator)
            else:
                x_t1_k = x_t0_k
            if x_t1_1 is None:
                raise missing_checkpoint("x_t1_1", t1)
            x_t1 = x_t1_k.clone().detach()
            x0 = run_ddim(x_t1, skip_t=t1, t0=-1, t1=-1)["x0"].detach()
            del x_t1
            del x_t1_k

        if use_motion_field and not inject_noise_to_warp:
            if x_t0_1 is None:
                raise missing_checkpoint("x_t0_1", t0)
            if x_t1_1 is None:
                raise missing_checkpoint("x_t1_1", t1)
            del x0

            # Frames 2..n start as warped copies of the first frame at t0.
            x_t0_k = x_t0_1[:, :, :1, :, :].repeat(1, 1, video_length - 1, 1, 1)
            reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(
                motion_field_strength_x=motion_field_strength_x,
                motion_field_strength_y=motion_field_strength_y,
                latents=x_t0_k, video_length=video_length,
                inject_noise_to_warp=inject_noise_to_warp, frame_ids=frame_ids)

            # assuming t0=t1=1000, if t0 = 1000
            if t1 > t0:
                x_t1_k = self.DDPM_forward(
                    x0=x_t0_k, t0=t0, tMax=t1, device=device, shape=shape,
                    text_embeddings=text_embeddings, generator=generator)
            else:
                x_t1_k = x_t0_k

            x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2).clone().detach()
            x0 = run_ddim(x_t1, skip_t=t1, t0=-1, t1=-1)["x0"].detach()
            del x_t1_k
            if not smooth_bg:
                # The smoothing stage still reads x_t1 / x_t1_1; freeing them
                # unconditionally made smooth_bg fail with UnboundLocalError.
                del x_t1
                del x_t1_1
        elif smooth_bg:
            # Only the smoothing stage needs these from here on.
            if x_t1_1 is None:
                raise missing_checkpoint("x_t1_1", t1)
            x_t1 = x_t1_1.clone()
            x_t1_1 = x_t1_1[:, :, :1, :, :].clone()

        # smooth background
        if smooth_bg:
            if use_motion_field and reference_flow is None:
                raise RuntimeError(
                    "smooth_bg needs the motion field, which is not computed when "
                    "inject_noise_to_warp=True")

            h, w = x0.shape[3], x0.shape[4]
            M_FG = torch.zeros((batch_size, video_length, h, w),
                               device=x0.device).to(x0.dtype)
            kernel = torch.ones(5, 5, device=x0.device, dtype=x0.dtype)
            to_latent_size = T.Resize(
                size=(h, w), interpolation=T.InterpolationMode.NEAREST)
            for batch_idx, x0_b in enumerate(x0):
                z0_b = self.decode_latents(x0_b[None]).detach()
                z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
                for frame_idx, z0_f in enumerate(z0_b):
                    z0_f = torch.round(
                        z0_f * 255).cpu().numpy().astype(np.uint8)
                    # apply SOD detection
                    m_f = torch.tensor(self.sod_model.process_data(
                        z0_f), device=x0.device).to(x0.dtype)
                    mask = to_latent_size(m_f[None])
                    # Grow the foreground mask slightly.
                    mask = dilation(mask[None].to(x0.device), kernel)[0]
                    M_FG[batch_idx, frame_idx, :, :] = mask

            # First-frame latents with the foreground zeroed out.
            x_t1_1_fg_masked = x_t1_1 * \
                (1 - repeat(M_FG[:, 0, :, :],
                            "b w h -> b c 1 w h", c=x_t1_1.shape[1]))

            x_t1_1_fg_masked_moved = []
            for x_t1_1_fg_masked_b in x_t1_1_fg_masked:
                moved = x_t1_1_fg_masked_b.clone().repeat(
                    1, video_length - 1, 1, 1)[None]
                if use_motion_field:
                    moved = self.warp_latents_independently(
                        moved, reference_flow, inject_noise=False)
                x_t1_1_fg_masked_moved.append(
                    torch.cat([x_t1_1_fg_masked_b[None], moved], dim=2))
            x_t1_1_fg_masked_moved = torch.cat(x_t1_1_fg_masked_moved, dim=0)

            # Foreground mask of the first frame, moved along with the flow.
            M_FG_1 = M_FG[:, :1, :, :]
            M_FG_warped = []
            for m_fg_1_b in M_FG_1:
                m_fg_1_b = m_fg_1_b[None, None]
                m_fg_b = m_fg_1_b.repeat(1, 1, video_length - 1, 1, 1)
                if use_motion_field:
                    m_fg_b = self.warp_latents_independently(
                        m_fg_b.clone(), reference_flow, inject_noise=False)
                M_FG_warped.append(
                    torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
            M_FG_warped = torch.cat(M_FG_warped, dim=0)

            # Background = foreground in neither the frame nor the moved mask.
            channels = x0.shape[1]
            M_BG = (1 - M_FG) * (1 - M_FG_warped)
            M_BG = repeat(M_BG, "b f h w -> b c f h w", c=channels)
            a_convex = smooth_bg_strength

            latents = (1 - M_BG) * x_t1 + M_BG * (a_convex * x_t1 +
                                                  (1 - a_convex) * x_t1_1_fg_masked_moved)

            x0 = run_ddim(latents, skip_t=t1, t0=-1, t1=-1)["x0"].detach()
            del latents

        latents = x0

        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")
        torch.cuda.empty_cache()

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        else:
            image = self.decode_latents(latents)
            # Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, text_embeddings.dtype)
            image = rearrange(image, "b c f h w -> (b f) h w c")

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)
        return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)