import os
import torch
import einops

from diffusers import DiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer
from huggingface_hub import snapshot_download

from diffusers_vdm.vae import VideoAutoencoderKL
from diffusers_vdm.projection import Resampler
from diffusers_vdm.unet import UNet3DModel
from diffusers_vdm.improved_clip_vision import ImprovedCLIPVisionModelWithProjection
from diffusers_vdm.dynamic_tsnr_sampler import SamplerDynamicTSNR

class LatentVideoDiffusionPipeline(DiffusionPipeline):
    def __init__(self, tokenizer, text_encoder, image_encoder, vae, image_projection, unet, fp16=True, eval=True):
        super().__init__()

        self.loading_components = dict(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            image_encoder=image_encoder,
            image_projection=image_projection
        )

        for k, v in self.loading_components.items():
            setattr(self, k, v)

        if fp16:
            self.vae.half()
            self.text_encoder.half()
            self.unet.half()
            self.image_encoder.half()
            self.image_projection.half()

        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.image_encoder.requires_grad_(False)

        self.vae.eval()
        self.text_encoder.eval()
        self.image_encoder.eval()

        if eval:
            self.unet.eval()
            self.image_projection.eval()
        else:
            self.unet.train()
            self.image_projection.train()
    def to(self, *args, **kwargs):
        for k, v in self.loading_components.items():
            if hasattr(v, 'to'):
                v.to(*args, **kwargs)
        return self

    def save_pretrained(self, save_directory, **kwargs):
        for k, v in self.loading_components.items():
            folder = os.path.join(save_directory, k)
            os.makedirs(folder, exist_ok=True)
            v.save_pretrained(folder)
        return
    @classmethod
    def from_pretrained(cls, repo_id, fp16=True, eval=True, token=None):
        local_folder = snapshot_download(repo_id=repo_id, token=token)
        return cls(
            tokenizer=CLIPTokenizer.from_pretrained(os.path.join(local_folder, "tokenizer")),
            text_encoder=CLIPTextModel.from_pretrained(os.path.join(local_folder, "text_encoder")),
            image_encoder=ImprovedCLIPVisionModelWithProjection.from_pretrained(os.path.join(local_folder, "image_encoder")),
            vae=VideoAutoencoderKL.from_pretrained(os.path.join(local_folder, "vae")),
            image_projection=Resampler.from_pretrained(os.path.join(local_folder, "image_projection")),
            unet=UNet3DModel.from_pretrained(os.path.join(local_folder, "unet")),
            fp16=fp16,
            eval=eval
        )
    def encode_cropped_prompt_77tokens(self, prompt: str):
        # Tokenize with padding/truncation to the fixed 77-token CLIP context window and return the text encoder hidden states.
        cond_ids = self.tokenizer(prompt,
                                  padding="max_length",
                                  max_length=self.tokenizer.model_max_length,
                                  truncation=True,
                                  return_tensors="pt").input_ids.to(self.text_encoder.device)
        cond = self.text_encoder(cond_ids, attention_mask=None).last_hidden_state
        return cond

    def encode_clip_vision(self, frames):
        # Fold time into the batch dimension so each frame is encoded independently, then unfold back to (b, t, ...).
        b, c, t, h, w = frames.shape
        frames = einops.rearrange(frames, 'b c t h w -> (b t) c h w')
        clipvision_embed = self.image_encoder(frames).last_hidden_state
        clipvision_embed = einops.rearrange(clipvision_embed, '(b t) d c -> b t d c', t=t)
        return clipvision_embed

    def encode_latents(self, videos, return_hidden_states=True):
        b, c, t, h, w = videos.shape
        x = einops.rearrange(videos, 'b c t h w -> (b t) c h w')
        encoder_posterior, hidden_states = self.vae.encode(x, return_hidden_states=return_hidden_states)
        z = encoder_posterior.mode() * self.vae.scale_factor
        z = einops.rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)

        if not return_hidden_states:
            return z

        hidden_states = [einops.rearrange(h, '(b t) c h w -> b c t h w', b=b) for h in hidden_states]
        hidden_states = [h[:, :, [0, -1], :, :] for h in hidden_states]  # only need first and last frames
        return z, hidden_states

    def decode_latents(self, latents, hidden_states):
        # Undo the scale factor and decode back to pixels, conditioning the VAE decoder on the reference hidden states.
        B, C, T, H, W = latents.shape
        latents = einops.rearrange(latents, 'b c t h w -> (b t) c h w')
        latents = latents.to(device=self.vae.device, dtype=self.vae.dtype) / self.vae.scale_factor
        pixels = self.vae.decode(latents, ref_context=hidden_states, timesteps=T)
        pixels = einops.rearrange(pixels, '(b t) c h w -> b c t h w', b=B, t=T)
        return pixels
    def __call__(
            self,
            batch_size: int = 1,
            steps: int = 50,
            guidance_scale: float = 5.0,
            positive_text_cond=None,
            negative_text_cond=None,
            positive_image_cond=None,
            negative_image_cond=None,
            concat_cond=None,
            fs=3,
            progress_tqdm=None,
    ):
        # Temporarily switch the UNet to eval mode for sampling; its original mode is restored before returning.
        unet_is_training = self.unet.training

        if unet_is_training:
            self.unet.eval()

        device = self.unet.device
        dtype = self.unet.dtype

        dynamic_tsnr_model = SamplerDynamicTSNR(self.unet)

        # Batch
        concat_cond = concat_cond.repeat(batch_size, 1, 1, 1, 1).to(device=device, dtype=dtype)  # b, c, t, h, w
        positive_text_cond = positive_text_cond.repeat(batch_size, 1, 1).to(concat_cond)  # b, f, c
        negative_text_cond = negative_text_cond.repeat(batch_size, 1, 1).to(concat_cond)  # b, f, c
        positive_image_cond = positive_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)  # b, t, l, c
        negative_image_cond = negative_image_cond.repeat(batch_size, 1, 1, 1).to(concat_cond)

        if isinstance(fs, torch.Tensor):
            fs = fs.repeat(batch_size, ).to(dtype=torch.long, device=device)  # b
        else:
            fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)  # b

        # Initial latents
        latent_shape = concat_cond.shape

        # Feeds
        sampler_kwargs = dict(
            cfg_scale=guidance_scale,
            positive=dict(
                context_text=positive_text_cond,
                context_img=positive_image_cond,
                fs=fs,
                concat_cond=concat_cond
            ),
            negative=dict(
                context_text=negative_text_cond,
                context_img=negative_image_cond,
                fs=fs,
                concat_cond=concat_cond
            )
        )

        # Sample
        results = dynamic_tsnr_model(latent_shape, steps, extra_args=sampler_kwargs, progress_tqdm=progress_tqdm)

        if unet_is_training:
            self.unet.train()

        return results
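

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline). A minimal, hedged
# example of how the methods above fit together, assuming a hypothetical
# checkpoint repo id "your-org/your-vdm-checkpoint" and zero-valued dummy frames
# standing in for real preprocessed video. How the conditioning tensors
# (especially concat_cond) should actually be built is defined by the
# surrounding project, not by this sketch.
if __name__ == '__main__':
    pipe = LatentVideoDiffusionPipeline.from_pretrained('your-org/your-vdm-checkpoint')  # placeholder repo id
    pipe.to('cuda')

    with torch.inference_mode():
        # Text conditioning: positive and negative prompts, each cropped to 77 tokens.
        positive_text = pipe.encode_cropped_prompt_77tokens('a timelapse of drifting clouds')
        negative_text = pipe.encode_cropped_prompt_77tokens('')

        # Dummy conditioning frames of shape (b, c, t, h, w); real inputs come from preprocessed video.
        frames = torch.zeros(1, 3, 16, 256, 256, device=pipe.unet.device, dtype=pipe.unet.dtype)
        image_cond = pipe.image_projection(pipe.encode_clip_vision(frames))
        latents, vae_hidden_states = pipe.encode_latents(frames, return_hidden_states=True)

        video_latents = pipe(
            batch_size=1,
            steps=25,
            guidance_scale=5.0,
            positive_text_cond=positive_text,
            negative_text_cond=negative_text,
            positive_image_cond=image_cond,
            negative_image_cond=torch.zeros_like(image_cond),
            concat_cond=latents,
            fs=3,
        )

        # Decode the sampled latents back to pixel-space video frames.
        video = pipe.decode_latents(video_latents, vae_hidden_states)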