Text2Video-Zero / text_to_video_pipeline.py
lev1's picture
Back to diffusers 0.14.x
cbea3f9
from diffusers import StableDiffusionPipeline
import torch
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy as np
from diffusers.utils import deprecate, logging, BaseOutput
from einops import rearrange, repeat
from torch.nn.functional import grid_sample
import torchvision.transforms as T
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
import PIL
from PIL import Image
from kornia.morphology import dilation
@dataclass
class TextToVideoPipelineOutput(BaseOutput):
# videos: Union[torch.Tensor, np.ndarray]
# code: Union[torch.Tensor, np.ndarray]
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
def coords_grid(batch, ht, wd, device):
# Adapted from https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
coords = torch.meshgrid(torch.arange(
ht, device=device), torch.arange(wd, device=device))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
class TextToVideoPipeline(StableDiffusionPipeline):
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
safety_checker, feature_extractor, requires_safety_checker)
def DDPM_forward(self, x0, t0, tMax, generator, device, shape, text_embeddings):
rand_device = "cpu" if device.type == "mps" else device
if x0 is None:
return torch.randn(shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype).to(device)
else:
eps = torch.randn(x0.shape, dtype=text_embeddings.dtype, generator=generator,
device=rand_device)
alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
xt = torch.sqrt(alpha_vec) * x0 + \
torch.sqrt(1-alpha_vec) * eps
return xt
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, video_length, height //
self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
rand_device = "cpu" if device.type == "mps" else device
if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(
shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(
shape, generator=generator, device=rand_device, dtype=dtype).to(device)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def warp_latents_independently(self, latents, reference_flow):
_, _, H, W = reference_flow.size()
b, _, f, h, w = latents.size()
assert b == 1
coords0 = coords_grid(f, H, W, device=latents.device).to(latents.dtype)
coords_t0 = coords0 + reference_flow
coords_t0[:, 0] /= W
coords_t0[:, 1] /= H
coords_t0 = coords_t0 * 2.0 - 1.0
coords_t0 = T.Resize((h, w))(coords_t0)
coords_t0 = rearrange(coords_t0, 'f c h w -> f h w c')
latents_0 = rearrange(latents[0], 'c f h w -> f c h w')
warped = grid_sample(latents_0, coords_t0,
mode='nearest', padding_mode='reflection')
warped = rearrange(warped, '(b f) c h w -> b c f h w', f=f)
return warped
def DDIM_backward(self, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance, null_embs, text_embeddings, latents_local,
latents_dtype, guidance_scale, guidance_stop_step, callback, callback_steps, extra_step_kwargs, num_warmup_steps):
entered = False
f = latents_local.shape[2]
latents_local = rearrange(latents_local, "b c f w h -> (b f) c w h")
latents = latents_local.detach().clone()
x_t0_1 = None
x_t1_1 = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if t > skip_t:
continue
else:
if not entered:
print(
f"Continue DDIM with i = {i}, t = {t}, latent = {latents.shape}, device = {latents.device}, type = {latents.dtype}")
entered = True
latents = latents.detach()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat(
[latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t)
# predict the noise residual
with torch.no_grad():
if null_embs is not None:
text_embeddings[0] = null_embs[i][0]
te = torch.cat([repeat(text_embeddings[0, :, :], "c k -> f c k", f=f),
repeat(text_embeddings[1, :, :], "c k -> f c k", f=f)])
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=te).sample.to(dtype=latents_dtype)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(
2)
noise_pred = noise_pred_uncond + guidance_scale * \
(noise_pred_text - noise_pred_uncond)
if i >= guidance_stop_step * len(timesteps):
alpha = 0
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs).prev_sample
# latents = latents - alpha * grads / (torch.norm(grads) + 1e-10)
# call the callback, if provided
if i < len(timesteps)-1 and timesteps[i+1] == t0:
x_t0_1 = latents.detach().clone()
print(f"latent t0 found at i = {i}, t = {t}")
elif i < len(timesteps)-1 and timesteps[i+1] == t1:
x_t1_1 = latents.detach().clone()
print(f"latent t1 found at i={i}, t = {t}")
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)
res = {"x0": latents.detach().clone()}
if x_t0_1 is not None:
x_t0_1 = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f)
res["x_t0_1"] = x_t0_1.detach().clone()
if x_t1_1 is not None:
x_t1_1 = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f)
res["x_t1_1"] = x_t1_1.detach().clone()
return res
def decode_latents(self, latents):
video_length = latents.shape[2]
latents = 1 / 0.18215 * latents
latents = rearrange(latents, "b c f h w -> (b f) c h w")
video = self.vae.decode(latents).sample
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
video = video.detach().cpu()
return video
def create_motion_field(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
reference_flow = torch.zeros(
(video_length-1, 2, 512, 512), device=latents.device, dtype=latents.dtype)
for fr_idx, frame_id in enumerate(frame_ids):
reference_flow[fr_idx, 0, :,
:] = motion_field_strength_x*(frame_id)
reference_flow[fr_idx, 1, :,
:] = motion_field_strength_y*(frame_id)
return reference_flow
def create_motion_field_and_warp_latents(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
motion_field = self.create_motion_field(motion_field_strength_x=motion_field_strength_x,
motion_field_strength_y=motion_field_strength_y, latents=latents, video_length=video_length, frame_ids=frame_ids)
for idx, latent in enumerate(latents):
latents[idx] = self.warp_latents_independently(
latent[None], motion_field)
return motion_field, latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
video_length: Optional[int],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
guidance_stop_step: float = 0.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
xT: Optional[torch.FloatTensor] = None,
null_embs: Optional[torch.FloatTensor] = None,
motion_field_strength_x: float = 12,
motion_field_strength_y: float = 12,
output_type: Optional[str] = "tensor",
return_dict: bool = True,
callback: Optional[Callable[[
int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
use_motion_field: bool = True,
smooth_bg: bool = False,
smooth_bg_strength: float = 0.4,
t0: int = 44,
t1: int = 47,
**kwargs,
):
frame_ids = kwargs.pop("frame_ids", list(range(video_length)))
assert t0 < t1
assert num_videos_per_prompt == 1
assert isinstance(prompt, list) and len(prompt) > 0
assert isinstance(negative_prompt, list) or negative_prompt is None
prompt_types = [prompt, negative_prompt]
for idx, prompt_type in enumerate(prompt_types):
prompt_template = None
for prompt in prompt_type:
if prompt_template is None:
prompt_template = prompt
else:
assert prompt == prompt_template
if prompt_types[idx] is not None:
prompt_types[idx] = prompt_types[idx][0]
prompt = prompt_types[0]
negative_prompt = prompt_types[1]
# Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
)
# Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# print(f" Latent shape = {latents.shape}")
# Prepare latent variables
num_channels_latents = self.unet.in_channels
xT = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
1,
height,
width,
text_embeddings.dtype,
device,
generator,
xT,
)
dtype = xT.dtype
# when motion field is not used, augment with random latent codes
if use_motion_field:
xT = xT[:, :, :1]
else:
if xT.shape[2] < video_length:
xT_missing = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
video_length-xT.shape[2],
height,
width,
text_embeddings.dtype,
device,
generator,
None,
)
xT = torch.cat([xT, xT_missing], dim=2)
xInit = xT.clone()
timesteps_ddpm = [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721,
701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441,
421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161,
141, 121, 101, 81, 61, 41, 21, 1]
timesteps_ddpm.reverse()
t0 = timesteps_ddpm[t0]
t1 = timesteps_ddpm[t1]
print(f"t0 = {t0} t1 = {t1}")
x_t1_1 = None
# Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# Denoising loop
num_warmup_steps = len(timesteps) - \
num_inference_steps * self.scheduler.order
shape = (batch_size, num_channels_latents, 1, height //
self.vae_scale_factor, width // self.vae_scale_factor)
ddim_res = self.DDIM_backward(num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=1000, t0=t0, t1=t1, do_classifier_free_guidance=do_classifier_free_guidance,
null_embs=null_embs, text_embeddings=text_embeddings, latents_local=xT, latents_dtype=dtype, guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step,
callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)
x0 = ddim_res["x0"].detach()
if "x_t0_1" in ddim_res:
x_t0_1 = ddim_res["x_t0_1"].detach()
if "x_t1_1" in ddim_res:
x_t1_1 = ddim_res["x_t1_1"].detach()
del ddim_res
del xT
if use_motion_field:
del x0
x_t0_k = x_t0_1[:, :, :1, :, :].repeat(1, 1, video_length-1, 1, 1)
reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(
motion_field_strength_x=motion_field_strength_x, motion_field_strength_y=motion_field_strength_y, latents=x_t0_k, video_length=video_length, frame_ids=frame_ids[1:])
# assuming t0=t1=1000, if t0 = 1000
if t1 > t0:
x_t1_k = self.DDPM_forward(
x0=x_t0_k, t0=t0, tMax=t1, device=device, shape=shape, text_embeddings=text_embeddings, generator=generator)
else:
x_t1_k = x_t0_k
if x_t1_1 is None:
raise Exception
x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2).clone().detach()
ddim_res = self.DDIM_backward(num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=t1, t0=-1, t1=-1, do_classifier_free_guidance=do_classifier_free_guidance,
null_embs=null_embs, text_embeddings=text_embeddings, latents_local=x_t1, latents_dtype=dtype, guidance_scale=guidance_scale,
guidance_stop_step=guidance_stop_step, callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)
x0 = ddim_res["x0"].detach()
del ddim_res
del x_t1
del x_t1_1
del x_t1_k
else:
x_t1 = x_t1_1.clone()
x_t1_1 = x_t1_1[:, :, :1, :, :].clone()
x_t1_k = x_t1_1[:, :, 1:, :, :].clone()
x_t0_k = x_t0_1[:, :, 1:, :, :].clone()
x_t0_1 = x_t0_1[:, :, :1, :, :].clone()
# smooth background
if smooth_bg:
h, w = x0.shape[3], x0.shape[4]
M_FG = torch.zeros((batch_size, video_length, h, w),
device=x0.device).to(x0.dtype)
for batch_idx, x0_b in enumerate(x0):
z0_b = self.decode_latents(x0_b[None]).detach()
z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
for frame_idx, z0_f in enumerate(z0_b):
z0_f = torch.round(
z0_f * 255).cpu().numpy().astype(np.uint8)
# apply SOD detection
m_f = torch.tensor(self.sod_model.process_data(
z0_f), device=x0.device).to(x0.dtype)
mask = T.Resize(
size=(h, w), interpolation=T.InterpolationMode.NEAREST)(m_f[None])
kernel = torch.ones(5, 5, device=x0.device, dtype=x0.dtype)
mask = dilation(mask[None].to(x0.device), kernel)[0]
M_FG[batch_idx, frame_idx, :, :] = mask
x_t1_1_fg_masked = x_t1_1 * \
(1 - repeat(M_FG[:, 0, :, :],
"b w h -> b c 1 w h", c=x_t1_1.shape[1]))
x_t1_1_fg_masked_moved = []
for batch_idx, x_t1_1_fg_masked_b in enumerate(x_t1_1_fg_masked):
x_t1_fg_masked_b = x_t1_1_fg_masked_b.clone()
x_t1_fg_masked_b = x_t1_fg_masked_b.repeat(
1, video_length-1, 1, 1)
if use_motion_field:
x_t1_fg_masked_b = x_t1_fg_masked_b[None]
x_t1_fg_masked_b = self.warp_latents_independently(
x_t1_fg_masked_b, reference_flow)
else:
x_t1_fg_masked_b = x_t1_fg_masked_b[None]
x_t1_fg_masked_b = torch.cat(
[x_t1_1_fg_masked_b[None], x_t1_fg_masked_b], dim=2)
x_t1_1_fg_masked_moved.append(x_t1_fg_masked_b)
x_t1_1_fg_masked_moved = torch.cat(x_t1_1_fg_masked_moved, dim=0)
M_FG_1 = M_FG[:, :1, :, :]
M_FG_warped = []
for batch_idx, m_fg_1_b in enumerate(M_FG_1):
m_fg_1_b = m_fg_1_b[None, None]
m_fg_b = m_fg_1_b.repeat(1, 1, video_length-1, 1, 1)
if use_motion_field:
m_fg_b = self.warp_latents_independently(
m_fg_b.clone(), reference_flow)
M_FG_warped.append(
torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
M_FG_warped = torch.cat(M_FG_warped, dim=0)
channels = x0.shape[1]
M_BG = (1-M_FG) * (1 - M_FG_warped)
M_BG = repeat(M_BG, "b f h w -> b c f h w", c=channels)
a_convex = smooth_bg_strength
latents = (1-M_BG) * x_t1 + M_BG * (a_convex *
x_t1 + (1-a_convex) * x_t1_1_fg_masked_moved)
ddim_res = self.DDIM_backward(num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=t1, t0=-1, t1=-1, do_classifier_free_guidance=do_classifier_free_guidance,
null_embs=null_embs, text_embeddings=text_embeddings, latents_local=latents, latents_dtype=dtype, guidance_scale=guidance_scale,
guidance_stop_step=guidance_stop_step, callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)
x0 = ddim_res["x0"].detach()
del ddim_res
del latents
latents = x0
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
torch.cuda.empty_cache()
if output_type == "latent":
image = latents
has_nsfw_concept = None
else:
image = self.decode_latents(latents)
# Run safety checker
image, has_nsfw_concept = self.run_safety_checker(
image, device, text_embeddings.dtype)
image = rearrange(image, "b c f h w -> (b f) h w c")
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return TextToVideoPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)