Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,158 Bytes
6dd488f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import os
import torch
import einops
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import snapshot_download
from diffusers_vdm.vae import VideoAutoencoderKL
from diffusers_vdm.projection import Resampler
from diffusers_vdm.unet import UNet3DModel
from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR
class LatentVideoDiffusionPipeline(DiffusionPipeline):
    """Latent video diffusion pipeline conditioned on text, images and a concat tensor.

    Bundles a CLIP tokenizer/text encoder, an image encoder, a video VAE, an
    image-projection resampler and a 3D UNet. Sampling is delegated to
    ``SamplerDynamicTSNR`` with positive/negative conditioning and
    ``guidance_scale`` passed as ``cfg_scale``.
    """

    def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
        """Store the components and set their precision and train/eval mode.

        Args:
            tokenizer, text_encoder, image_encoder, vae, image_projection, unet:
                Pipeline components; each becomes an attribute of the same name
                and is recorded in ``self.loading_components``.
            fp16: If True, cast every component except the tokenizer to half precision.
            eval: If True, put ``unet`` and ``image_projection`` in eval mode,
                otherwise in train mode. (Shadows the builtin ``eval`` but is
                part of the public signature, so it is kept.)

        The VAE, text encoder and image encoder are always frozen
        (``requires_grad_(False)``) and put in eval mode.
        """
        super().__init__()
        self.loading_components = dict(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            image_encoder=image_encoder,
            image_projection=image_projection,
        )
        for name, component in self.loading_components.items():
            setattr(self, name, component)

        if fp16:
            self.vae.half()
            self.text_encoder.half()
            self.unet.half()
            self.image_encoder.half()
            self.image_projection.half()

        # These components are never trained by this pipeline: freeze them.
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.vae.eval()
        self.text_encoder.eval()
        self.image_encoder.eval()

        # The potentially trainable parts follow the `eval` flag.
        if eval:
            self.unet.eval()
            self.image_projection.eval()
        else:
            self.unet.train()
            self.image_projection.train()

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every component that has it; return ``self``.

        Components without a ``to`` method (presumably the tokenizer) are skipped.
        """
        for component in self.loading_components.values():
            if hasattr(component, 'to'):
                component.to(*args, **kwargs)
        return self

    def save_pretrained(self, save_directory, **kwargs):
        """Save each component into ``save_directory/<component_name>``.

        This is the sub-folder layout that ``from_pretrained`` reads back.
        ``kwargs`` are accepted for API compatibility but are currently ignored.
        """
        for name, component in self.loading_components.items():
            folder = os.path.join(save_directory, name)
            os.makedirs(folder, exist_ok=True)
            component.save_pretrained(folder)

    @classmethod
    def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
        """Download ``repo_id`` with ``snapshot_download`` and build the pipeline.

        Each component is loaded from the sub-folder named after it
        (``tokenizer``, ``text_encoder``, ``image_encoder``, ``vae``,
        ``image_projection``, ``unet``). ``fp16`` and ``eval`` are forwarded
        to ``__init__``; ``token`` is forwarded to ``snapshot_download``.
        """
        local_folder = snapshot_download(repo_id=repo_id, token=token)

        def subfolder(name):
            return os.path.join(local_folder, name)

        return cls(
            tokenizer=CLIPTokenizer.from_pretrained(subfolder("tokenizer")),
            text_encoder=CLIPTextModel.from_pretrained(subfolder("text_encoder")),
            image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(subfolder("image_encoder")),
            vae=VideoAutoencoderKL.from_pretrained(subfolder("vae")),
            image_projection=Resampler.from_pretrained(subfolder("image_projection")),
            unet=UNet3DModel.from_pretrained(subfolder("unet")),
            fp16=fp16,
            eval=eval,
        )

    @torch.inference_mode()
    def encode_cropped_prompt_77tokens(self, prompt: str):
        """Encode ``prompt`` with the text encoder and return its last hidden state.

        The prompt is padded and truncated to ``tokenizer.model_max_length``
        tokens (presumably 77, per the method name — depends on the tokenizer).
        No attention mask is passed to the encoder.
        """
        cond_ids = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.text_encoder.device)
        return self.text_encoder(cond_ids, attention_mask=None).last_hidden_state

    @torch.inference_mode()
    def encode_clip_vision(self, frames):
        """Encode every frame of a 5-D ``frames`` tensor (b, c, t, h, w) with the image encoder.

        Returns the encoder's ``last_hidden_state`` regrouped per frame as
        (b, t, d, c), following the einops patterns below.
        """
        _, _, t, _, _ = frames.shape  # also rejects inputs that are not 5-D
        flat_frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
        clipvision_embed = self.image_encoder(flat_frames).last_hidden_state
        return einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)

    @torch.inference_mode()
    def encode_latents(self, videos, return_hidden_states=True):
        """Encode a 5-D ``videos`` tensor (b, c, t, h, w) into scaled VAE latents.

        Args:
            videos: Video tensor laid out as (b, c, t, h, w).
            return_hidden_states: If True, also return the VAE encoder hidden
                states, each reshaped to (b, c, t, h, w) and reduced to the first
                and last frame along t.

        Returns:
            ``z`` of shape (b, c, t, h, w), the posterior mode multiplied by
            ``vae.scale_factor``; or ``(z, hidden_states)`` when
            ``return_hidden_states`` is True.
        """
        b, _, t, _, _ = videos.shape
        x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
        # NOTE(review): two values are unpacked even when return_hidden_states is
        # False, so vae.encode is assumed to always return a pair — confirm.
        encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
        z = encoder_posterior.mode() * self.vae.scale_factor
        z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
        if not return_hidden_states:
            return z
        # Only the first and last frames of each hidden state are needed.
        hidden_states = [
            einops.rearrange(state, '(b t) c h w -> b c t h w', b=b)[:, :, [0, -1], :, :]
            for state in hidden_states
        ]
        return z, hidden_states

    @torch.inference_mode()
    def decode_latents(self, latents, hidden_states):
        """Decode scaled latents (b, c, t, h, w) back to pixels laid out as (b, c, t, h, w).

        Latents are moved to the VAE's device/dtype and divided by
        ``vae.scale_factor`` before decoding. ``hidden_states`` is forwarded to
        the VAE decoder as ``ref_context`` (presumably the output of
        ``encode_latents`` — confirm against ``VideoAutoencoderKL.decode``).
        """
        b, _, t, _, _ = latents.shape
        latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
        latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
        pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=t)
        return einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=b, t=t)

    @torch.inference_mode()
    def __call__(
        self,
        batch_size: int = 1,
        steps: int = 50,
        guidance_scale: float = 5.0,
        positive_text_cond=None,
        negative_text_cond=None,
        positive_image_cond=None,
        negative_image_cond=None,
        concat_cond=None,
        fs=3,
        progress_tqdm=None,
    ):
        """Sample with ``SamplerDynamicTSNR`` from the given conditioning.

        Each conditioning tensor is tiled ``batch_size`` times along its leading
        dimension. ``concat_cond`` is cast to the UNet's device and dtype, and
        the other conditioning tensors are then cast to match it.

        Args:
            batch_size: Repeat factor applied to every conditioning tensor.
            steps: Number of sampling steps passed to the sampler.
            guidance_scale: Passed to the sampler as ``cfg_scale``.
            positive_text_cond, negative_text_cond: Text conditioning, tiled
                with ``repeat(batch_size, 1, 1)``.
            positive_image_cond, negative_image_cond: Image conditioning, tiled
                with ``repeat(batch_size, 1, 1, 1)``.
            concat_cond: Conditioning tiled with ``repeat(batch_size, 1, 1, 1, 1)``.
                Its shape after tiling is also the latent shape handed to the sampler.
            fs: int, or a tensor tiled with ``repeat(batch_size)``. It is converted
                to a long tensor on the UNet device. A 1-element tensor or an int
                yields shape (batch_size,).
            progress_tqdm: Forwarded to the sampler unchanged.

        Returns:
            Whatever ``SamplerDynamicTSNR`` returns (presumably the sampled
            latents — confirm).

        Raises:
            ValueError: If any of the five conditioning tensors is None.

        If the UNet was in training mode it is switched to eval for sampling
        and switched back afterwards, even if sampling raises.
        """
        required = dict(
            positive_text_cond=positive_text_cond,
            negative_text_cond=negative_text_cond,
            positive_image_cond=positive_image_cond,
            negative_image_cond=negative_image_cond,
            concat_cond=concat_cond,
        )
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(f"Missing required conditioning: {', '.join(missing)}")

        unet_is_training = self.unet.training
        if unet_is_training:
            self.unet.eval()
        try:
            device = self.unet.device
            dtype = self.unet.dtype
            dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)

            # Batch: tile each conditioning tensor and match concat_cond's device/dtype.
            concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype)  # b, c, t, h, w
            positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond)  # b, f, c
            negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond)  # b, f, c
            positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)  # b, t, l, c
            negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)

            if isinstance(fs, torch.Tensor):
                fs = fs.repeat(batch_size).to(dtype=torch.long, device=device)  # b
            else:
                fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)  # b

            # Initial latents share the (tiled) concat_cond shape.
            latent_shape = concat_cond.shape

            sampler_kwargs = dict(
                cfg_scale=guidance_scale,
                positive=dict(
                    context_text=positive_text_cond,
                    context_img=positive_image_cond,
                    fs=fs,
                    concat_cond=concat_cond,
                ),
                negative=dict(
                    context_text=negative_text_cond,
                    context_img=negative_image_cond,
                    fs=fs,
                    concat_cond=concat_cond,
                ),
            )

            return dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)
        finally:
            # Restore the caller's training mode even when sampling fails.
            if unet_is_training:
                self.unet.train()
|