from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput, deprecate, logging
from einops import rearrange, repeat
from torch.nn.functional import grid_sample, max_pool2d
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class TextToVideoPipelineOutput(BaseOutput):
    """Result of ``TextToVideoPipeline.__call__`` when ``return_dict`` is True."""

    # List with one entry per generated video (torch tensors for output_type="tensor",
    # numpy arrays for output_type="numpy").
    videos: Union[torch.Tensor, np.ndarray]
    # Initial latents used for generation, split into one tensor per batch element.
    code: Union[torch.Tensor, np.ndarray]


def coords_grid(batch, ht, wd, device):
    """Return a ``(batch, 2, ht, wd)`` float grid of pixel coordinates.

    Channel 0 holds the x (column) index and channel 1 the y (row) index.
    Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
    """
    coords = torch.meshgrid(
        torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij"
    )
    # meshgrid yields (rows, cols); reverse to (x, y) ordering.
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


class TextToVideoPipeline(StableDiffusionPipeline):
    """Zero-shot text-to-video generation on top of a Stable Diffusion image pipeline.

    The first frame's latent is denoised until timesteps ``t0``/``t1``. It is then
    copied to the remaining frames and translated by a global motion field. Optionally
    it is re-noised from ``t0`` to ``t1``, and finally all frames are denoised
    together from ``t1``.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        # Keyword arguments keep this correct on diffusers versions that insert extra
        # optional components (e.g. ``image_encoder``) before ``requires_safety_checker``.
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )

    def DDPM_forward(self, x0, t0, tMax, generator, device, shape, text_embeddings):
        """Noise ``x0`` forward from timestep ``t0`` to ``tMax``.

        If ``x0`` is None, fresh Gaussian noise of ``shape`` is returned instead.
        Otherwise this returns ``sqrt(a) * x0 + sqrt(1 - a) * eps``, where ``a`` is
        the product of ``scheduler.alphas[t0:tMax]``.
        """
        rand_device = "cpu" if device.type == "mps" else device

        if x0 is None:
            return torch.randn(
                shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
            ).to(device)

        # NOTE(review): eps does not use ``generator``, so this path is not
        # reproducible from a seed.
        eps = torch.randn_like(x0, dtype=text_embeddings.dtype).to(device)
        alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
        xt = torch.sqrt(alpha_vec) * x0 + torch.sqrt(1 - alpha_vec) * eps
        return xt

    def prepare_latents(
        self, batch_size, num_channels_latents, video_length, height, width, dtype, device,
        generator, latents=None,
    ):
        """Create (or move) 5-D initial latents and scale them by ``init_noise_sigma``.

        Shape: ``(batch, channels, video_length, height // vae_scale_factor,
        width // vae_scale_factor)``.
        """
        shape = (
            batch_size,
            num_channels_latents,
            video_length,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            rand_device = "cpu" if device.type == "mps" else device

            if isinstance(generator, list):
                # One generator per batch element, each producing a single sample.
                shape = (1,) + shape[1:]
                latents = [
                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                    for i in range(batch_size)
                ]
                latents = torch.cat(latents, dim=0).to(device)
            else:
                latents = torch.randn(
                    shape, generator=generator, device=rand_device, dtype=dtype
                ).to(device)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def warp_latents(self, latents, reference_flow):
        """Warp the first frame of ``latents`` by each flow in ``reference_flow``.

        NOTE(review): ``latents_0`` has ``b * f`` samples but the sampling grid has
        ``f``, so ``grid_sample`` only accepts this for ``b == 1``. This method is not
        called elsewhere in this file.
        """
        _, _, H, W = reference_flow.size()
        b, c, f, h, w = latents.size()
        coords0 = coords_grid(f, H, W, device=latents.device).to(latents.dtype)

        # Normalize the displaced pixel coordinates to grid_sample's [-1, 1] range.
        coords_t0 = coords0 + reference_flow
        coords_t0[:, 0] /= W
        coords_t0[:, 1] /= H
        coords_t0 = coords_t0 * 2.0 - 1.0

        coords_t0 = T.Resize((h, w))(coords_t0)
        coords_t0 = rearrange(coords_t0, "f c h w -> f h w c")

        latents_0 = latents[:, :, 0]
        latents_0 = latents_0.repeat(f, 1, 1, 1)
        warped = grid_sample(latents_0, coords_t0, mode="nearest", padding_mode="reflection")

        warped = rearrange(warped, "(b f) c h w -> b c f h w", f=f)
        return warped

    def warp_latents_independently(self, latents, reference_flow):
        """Warp frame ``k`` of ``latents`` (batch of one) by ``reference_flow[k]``.

        ``reference_flow`` is ``(f, 2, H, W)`` in pixel units of an ``H x W`` grid. The
        sampling grid is resized to the latent resolution before ``grid_sample``.
        """
        _, _, H, W = reference_flow.size()
        b, c, f, h, w = latents.size()
        assert b == 1
        coords0 = coords_grid(f, H, W, device=latents.device).to(latents.dtype)

        # Normalize the displaced pixel coordinates to grid_sample's [-1, 1] range.
        coords_t0 = coords0 + reference_flow
        coords_t0[:, 0] /= W
        coords_t0[:, 1] /= H
        coords_t0 = coords_t0 * 2.0 - 1.0

        coords_t0 = T.Resize((h, w))(coords_t0)
        coords_t0 = rearrange(coords_t0, "f c h w -> f h w c")

        latents_0 = rearrange(latents[0], "c f h w -> f c h w")
        warped = grid_sample(latents_0, coords_t0, mode="nearest", padding_mode="reflection")

        warped = rearrange(warped, "(b f) c h w -> b c f h w", f=f)
        return warped

    def DDIM_backward(
        self, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance,
        null_embs, text_embeddings, latents_local, latents_dtype, guidance_scale,
        guidance_stop_step, callback, callback_steps, extra_step_kwargs, num_warmup_steps,
    ):
        """Run the scheduler's denoising loop over a 5-D video latent.

        Frames are folded into the batch axis so the 2-D UNet denoises every frame,
        and they are unfolded again at the end.

        Args:
            skip_t: timesteps greater than this value are skipped, so denoising can
                resume from a partially-noised latent.
            t0, t1: timestep values at which to snapshot the latents. A snapshot is
                taken right after the step whose *next* timestep equals t0/t1. Pass
                -1 to disable.
            null_embs: optional per-step replacements for ``text_embeddings[0]``. The
                replacement is written into ``text_embeddings`` in place.
            text_embeddings: prompt embeddings as returned by ``_encode_prompt``.
                Each row is repeated once per frame.
            guidance_stop_step: accepted for API compatibility; currently unused.

        Returns:
            dict with ``"x0"`` (final latents) and, when found, ``"x_t0_1"`` and
            ``"x_t1_1"``, all shaped like ``latents_local``.
        """
        entered = False

        f = latents_local.shape[2]
        latents_local = rearrange(latents_local, "b c f w h -> (b f) c w h")

        latents = latents_local.detach().clone()
        x_t0_1 = None
        x_t1_1 = None

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if t > skip_t:
                    continue
                if not entered:
                    logger.info(
                        f"Continue DDIM with i = {i}, t = {t}, latent = {latents.shape}, device = {latents.device}, type = {latents.dtype}"
                    )
                    entered = True

                latents = latents.detach()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                with torch.no_grad():
                    if null_embs is not None:
                        text_embeddings[0] = null_embs[i][0]
                    # Repeat every embedding row once per frame so it lines up with the
                    # "(b f)" latent layout (and its CFG-doubled copy).
                    te = text_embeddings.repeat_interleave(f, dim=0)
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=te
                    ).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # Snapshot the latents for the timesteps the caller asked for. The two
                # checks are independent so t0 == t1 captures both.
                if i < len(timesteps) - 1:
                    next_t = timesteps[i + 1]
                    if next_t == t0:
                        x_t0_1 = latents.detach().clone()
                        logger.info(f"latent t0 found at i = {i}, t = {t}")
                    if next_t == t1:
                        x_t1_1 = latents.detach().clone()
                        logger.info(f"latent t1 found at i={i}, t = {t}")

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)
        res = {"x0": latents.detach().clone()}
        if x_t0_1 is not None:
            x_t0_1 = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t0_1"] = x_t0_1.detach().clone()
        if x_t1_1 is not None:
            x_t1_1 = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f)
            res["x_t1_1"] = x_t1_1.detach().clone()
        return res

    def decode_latents(self, latents):
        """Decode ``(b, c, f, h, w)`` latents into a ``(b, c, f, H, W)`` video in [0, 1].

        The result keeps the VAE's dtype and device; callers cast as needed.
        """
        video_length = latents.shape[2]
        # Undo the latent scaling; 0.18215 is the fallback for VAEs whose config
        # does not expose ``scaling_factor``.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        return video

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        guidance_stop_step: float = 0.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        xT: Optional[torch.FloatTensor] = None,
        null_embs: Optional[torch.FloatTensor] = None,
        motion_field_strength_x: float = 12,
        motion_field_strength_y: float = 12,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        use_motion_field: bool = True,
        smooth_bg: bool = True,
        smooth_bg_strength: float = 0.4,
        **kwargs,
    ):
        """Generate a video for ``prompt``.

        Args:
            video_length: number of frames.
            height, width: output size in pixels. They default to the UNet sample size
                times ``vae_scale_factor``.
            guidance_stop_step: forwarded to ``DDIM_backward``, which ignores it.
            xT: optional initial latents ``(b, c, f, h, w)`` at latent resolution.
            null_embs: optional per-step unconditional embeddings (see ``DDIM_backward``).
            motion_field_strength_x, motion_field_strength_y: per-frame translation
                applied to the first frame's latent. Frame ``k`` is shifted by
                ``k * strength``, measured in pixels of a fixed 512x512 reference grid.
            output_type: ``"tensor"`` gives a list of ``(b, c, f, H, W)`` tensors.
                ``"numpy"`` gives a list of ``((b f), H, W, c)`` arrays. Any other
                value gives an empty list.
            use_motion_field: if False, every frame starts from its own random latent.
            smooth_bg, smooth_bg_strength: blend background regions toward the warped
                first frame. This needs a salient-object model on ``self.sod_model``;
                without one, smoothing is skipped with a warning.
            t0, t1 (required keyword arguments): scheduler timestep values at which
                the initial denoising pass snapshots the latents.

        Returns:
            ``TextToVideoPipelineOutput(videos, code)``, or just ``videos`` when
            ``return_dict`` is False.
        """
        # Read the required keyword arguments up front so a missing one fails before
        # any expensive work.
        t0 = kwargs["t0"]
        t1 = kwargs["t1"]

        if smooth_bg and getattr(self, "sod_model", None) is None:
            logger.warning(
                "Background smoothing was requested but no `sod_model` is attached to the "
                "pipeline; skipping background smoothing."
            )
            smooth_bg = False

        logger.info(f"Motion field strength: x = {motion_field_strength_x}, y = {motion_field_strength_y}")
        logger.info(f" Use: Motion field = {use_motion_field}")
        logger.info(f" Use: Background smoothing = {smooth_bg}")

        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # `guidance_scale` is the guidance weight `w` of equation (2) of the Imagen
        # paper (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` disables
        # classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        xT = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            xT,
        )
        dtype = xT.dtype

        if use_motion_field:
            # Only the first frame is denoised; the others are derived from it.
            xT = xT[:, :, :1]
        elif xT.shape[2] < video_length:
            # Without a motion field, pad with independent random latents per frame.
            xT_missing = self.prepare_latents(
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                video_length - xT.shape[2],
                height,
                width,
                text_embeddings.dtype,
                device,
                generator,
                None,
            )
            xT = torch.cat([xT, xT_missing], dim=2)

        xInit = xT.clone()

        x_t0_1 = None
        x_t1_1 = None

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # Arguments shared by every DDIM_backward pass below.
        ddim_kwargs = dict(
            num_inference_steps=num_inference_steps,
            timesteps=timesteps,
            do_classifier_free_guidance=do_classifier_free_guidance,
            null_embs=null_embs,
            text_embeddings=text_embeddings,
            latents_dtype=dtype,
            guidance_scale=guidance_scale,
            guidance_stop_step=guidance_stop_step,
            callback=callback,
            callback_steps=callback_steps,
            extra_step_kwargs=extra_step_kwargs,
            num_warmup_steps=num_warmup_steps,
        )

        # First pass: full denoising from pure noise, snapshotting latents at t0 and t1.
        ddim_res = self.DDIM_backward(skip_t=1000, t0=t0, t1=t1, latents_local=xT, **ddim_kwargs)
        x0 = ddim_res["x0"].detach()
        if "x_t0_1" in ddim_res:
            x_t0_1 = ddim_res["x_t0_1"].detach()
        if "x_t1_1" in ddim_res:
            x_t1_1 = ddim_res["x_t1_1"].detach()
        del ddim_res
        del xT

        if use_motion_field:
            del x0
            if x_t0_1 is None or x_t1_1 is None:
                raise ValueError(
                    f"Timesteps t0={t0} and t1={t1} must both appear in the scheduler timesteps "
                    "(after the first step) to use the motion field."
                )

            shape = (
                batch_size,
                num_channels_latents,
                1,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
            )

            # Copy the first-frame latent at t0 to the remaining frames.
            x_t0_k = x_t0_1[:, :, :1, :, :].repeat(1, 1, video_length - 1, 1, 1)

            # Constant per-frame translation, growing linearly with the frame index.
            # NOTE(review): the flow grid is fixed at 512x512 regardless of height/width.
            reference_flow = torch.zeros(
                (video_length - 1, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype
            )
            for fr_idx in range(video_length - 1):
                reference_flow[fr_idx, 0, :, :] = motion_field_strength_x * (fr_idx + 1)
                reference_flow[fr_idx, 1, :, :] = motion_field_strength_y * (fr_idx + 1)

            for idx, latent in enumerate(x_t0_k):
                x_t0_k[idx] = self.warp_latents_independently(latent[None], reference_flow)

            # Re-noise the warped frames from t0 up to t1 when t1 is the noisier step.
            if t1 > t0:
                x_t1_k = self.DDPM_forward(
                    x0=x_t0_k,
                    t0=t0,
                    tMax=t1,
                    device=device,
                    shape=shape,
                    text_embeddings=text_embeddings,
                    generator=generator,
                )
            else:
                x_t1_k = x_t0_k

            x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2).clone().detach()

            # Second pass: denoise all frames jointly from t1.
            ddim_res = self.DDIM_backward(skip_t=t1, t0=-1, t1=-1, latents_local=x_t1, **ddim_kwargs)
            x0 = ddim_res["x0"].detach()
            del ddim_res
        elif x_t1_1 is not None:
            # Without a motion field every frame already has its own latent. Keep the full
            # video at t1 plus its first frame for the optional smoothing step.
            x_t1 = x_t1_1.clone()
            x_t1_1 = x_t1_1[:, :, :1, :, :].clone()

        # smooth background
        if smooth_bg:
            if x_t1_1 is None:
                raise ValueError(
                    f"Background smoothing needs the latents at t1={t1}, but t1 was not found "
                    "among the scheduler timesteps."
                )

            h, w = x0.shape[3], x0.shape[4]
            # Foreground mask per (video, frame) at latent resolution.
            M_FG = torch.zeros((x0.shape[0], video_length, h, w), device=x0.device, dtype=x0.dtype)
            resize_mask = T.Resize(size=(h, w), interpolation=T.InterpolationMode.NEAREST)
            for batch_idx, x0_b in enumerate(x0):
                z0_b = self.decode_latents(x0_b[None]).detach()
                z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
                for frame_idx, z0_f in enumerate(z0_b):
                    z0_f = torch.round(z0_f * 255).cpu().numpy().astype(np.uint8)
                    # apply SOD detection (assumed to return a foreground mask for the frame)
                    m_f = torch.tensor(self.sod_model.process_data(z0_f), device=x0.device).to(x0.dtype)
                    mask = resize_mask(m_f[None])
                    # Flat 5x5 dilation == 5x5 max filter (max_pool2d pads with -inf,
                    # so borders only see in-image pixels).
                    mask = max_pool2d(mask[None].to(x0.device), kernel_size=5, stride=1, padding=2)[0]
                    M_FG[batch_idx, frame_idx, :, :] = mask

            # First frame at t1 with its foreground removed.
            x_t1_1_fg_masked = x_t1_1 * (
                1 - repeat(M_FG[:, 0, :, :], "b w h -> b c 1 w h", c=x_t1_1.shape[1])
            )

            # Propagate the masked first frame to all frames (warped if using motion).
            x_t1_1_fg_masked_moved = []
            for x_t1_1_fg_masked_b in x_t1_1_fg_masked:
                x_t1_fg_masked_b = x_t1_1_fg_masked_b.clone().repeat(1, video_length - 1, 1, 1)[None]
                if use_motion_field:
                    x_t1_fg_masked_b = self.warp_latents_independently(x_t1_fg_masked_b, reference_flow)
                x_t1_fg_masked_b = torch.cat([x_t1_1_fg_masked_b[None], x_t1_fg_masked_b], dim=2)
                x_t1_1_fg_masked_moved.append(x_t1_fg_masked_b)
            x_t1_1_fg_masked_moved = torch.cat(x_t1_1_fg_masked_moved, dim=0)

            # Propagate the first frame's foreground mask the same way.
            M_FG_1 = M_FG[:, :1, :, :]
            M_FG_warped = []
            for m_fg_1_b in M_FG_1:
                m_fg_1_b = m_fg_1_b[None, None]
                m_fg_b = m_fg_1_b.repeat(1, 1, video_length - 1, 1, 1)
                if use_motion_field:
                    m_fg_b = self.warp_latents_independently(m_fg_b.clone(), reference_flow)
                M_FG_warped.append(torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
            M_FG_warped = torch.cat(M_FG_warped, dim=0)

            # Background = not foreground in the frame and not foreground after warping.
            channels = x0.shape[1]
            M_BG = (1 - M_FG) * (1 - M_FG_warped)
            M_BG = repeat(M_BG, "b f h w -> b c f h w", c=channels)
            a_convex = smooth_bg_strength

            latents = (1 - M_BG) * x_t1 + M_BG * (
                a_convex * x_t1 + (1 - a_convex) * x_t1_1_fg_masked_moved
            )

            # Re-denoise the blended latents from t1.
            ddim_res = self.DDIM_backward(skip_t=t1, t0=-1, t1=-1, latents_local=latents, **ddim_kwargs)
            x0 = ddim_res["x0"].detach()
            del ddim_res

        # Post-processing: decode one frame at a time to bound VAE memory use.
        video_list = []
        for latent in x0:
            tmp = latent[None]
            logger.debug(f"Frame split shape {tmp.shape}")
            frames = []
            for fr_split in range(tmp.shape[2]):
                logger.debug("frame decoding")
                frames.append(self.decode_latents(tmp[:, :, fr_split, None]).detach())
            video_list.append(torch.cat(frames, dim=2).cpu().float().numpy())

        # Convert to the requested output type.
        videos = []
        if output_type == "tensor":
            videos = [torch.from_numpy(video) for video in video_list]
        elif output_type == "numpy":
            videos = [rearrange(video, "b c f h w -> (b f) h w c") for video in video_list]

        if not return_dict:
            return videos

        return TextToVideoPipelineOutput(videos=videos, code=torch.split(xInit.detach().cpu(), 1, dim=0))