# Page-header residue from the hosting site (not part of the module): "Spaces: Running on Zero"
import torch | |
import torch.nn as nn | |
import numpy as np | |
import os | |
from typing import List | |
from diffusers import StableDiffusionPipeline | |
from diffusers.pipelines.controlnet import MultiControlNetModel | |
from PIL import Image | |
from safetensors import safe_open | |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
from foleycrafter.models.adapters.resampler import Resampler | |
from foleycrafter.models.adapters.utils import is_torch2_available | |
class IPAdapter(torch.nn.Module):
    """Bundle a UNet with an image-projection model and adapter modules.

    The projection model turns image embeddings into extra context tokens,
    which are appended to the encoder hidden states before the UNet is called.

    Args:
        unet: Callable invoked as
            ``unet(noisy_latents, timesteps, encoder_hidden_states)``; the
            object it returns must expose a ``.sample`` attribute.
        image_proj_model: Module mapping ``image_embeds`` to context tokens.
        adapter_modules: Module holding the adapter weights. In this class it
            is only touched when a checkpoint is loaded.
        ckpt_path: Optional checkpoint path. When given, the weights are
            loaded immediately through :meth:`load_from_checkpoint`.
    """

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        """Predict the noise residual conditioned on the extra image tokens.

        Args:
            noisy_latents: Passed through to the UNet unchanged.
            timesteps: Passed through to the UNet unchanged.
            encoder_hidden_states: Tensor the projected image tokens are
                concatenated to along dim 1.
            image_embeds: Input of ``image_proj_model``.

        Returns:
            The ``.sample`` attribute of the UNet output.
        """
        ip_tokens = self.image_proj_model(image_embeds)
        # Image tokens are appended *after* the existing hidden states.
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        return self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

    @staticmethod
    def _parameter_checksum(module):
        """Return the sum of all parameter values of ``module`` as a 0-dim tensor."""
        # The value is only a fingerprint, so no autograd graph is needed.
        with torch.no_grad():
            return torch.sum(torch.stack([torch.sum(p) for p in module.parameters()]))

    def load_from_checkpoint(self, ckpt_path: str) -> None:
        """Load ``image_proj`` and ``ip_adapter`` weights from a checkpoint file.

        The checkpoint must be a mapping with the keys ``"image_proj"`` and
        ``"ip_adapter"``; both state dicts are loaded with ``strict=True``.

        Args:
            ckpt_path: Path handed to ``torch.load``.

        Raises:
            AssertionError: If the parameter checksum of either module is the
                same before and after loading (i.e. nothing seems to have
                been loaded).
        """
        # Fingerprints taken before loading, compared again afterwards.
        orig_ip_proj_sum = self._parameter_checksum(self.image_proj_model)
        orig_adapter_sum = self._parameter_checksum(self.adapter_modules)

        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from trusted sources (or pass ``weights_only=True`` where the
        # installed torch version supports it).
        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        new_ip_proj_sum = self._parameter_checksum(self.image_proj_model)
        new_adapter_sum = self._parameter_checksum(self.adapter_modules)

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")
class VideoProjModel(torch.nn.Module):
    """Linear projection followed by LayerNorm over the last dimension.

    The input is projected from ``clip_embeddings_dim`` to
    ``clip_extra_context_tokens * cross_attention_dim`` features and then
    normalised; leading dimensions are left untouched.

    NOTE(review): the LayerNorm is sized ``cross_attention_dim`` while the
    projection emits ``clip_extra_context_tokens * cross_attention_dim``
    features, so only ``clip_extra_context_tokens=1`` (the default) makes the
    two agree in ``forward``.

    Args:
        cross_attention_dim: Feature size of the normalised output.
        clip_embeddings_dim: Feature size of the incoming embeddings.
        clip_extra_context_tokens: Multiplier applied to the projection width.
        video_frame: Stored on the instance as ``video_frame``; not used by
            this class itself.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=1, video_frame=50):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

        self.video_frame = video_frame

    def forward(self, image_embeds):
        """Project ``image_embeds`` and layer-normalise the result.

        Args:
            image_embeds: Tensor whose last dimension is ``clip_embeddings_dim``.

        Returns:
            Tensor with the same leading dimensions as the input.
        """
        projected = self.proj(image_embeds)
        return self.norm(projected)
class ImageProjModel(torch.nn.Module):
    """Project an image embedding into several normalised context tokens.

    A single linear layer maps ``clip_embeddings_dim`` features to
    ``clip_extra_context_tokens * cross_attention_dim`` features. The result
    is reshaped to ``(-1, clip_extra_context_tokens, cross_attention_dim)``
    and layer-normalised per token.

    Args:
        cross_attention_dim: Feature size of every produced token.
        clip_embeddings_dim: Feature size of the incoming embeddings.
        clip_extra_context_tokens: Number of tokens produced per embedding.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return the context tokens for ``image_embeds``.

        Args:
            image_embeds: Tensor whose last dimension is ``clip_embeddings_dim``.

        Returns:
            Tensor of shape ``(-1, clip_extra_context_tokens,
            cross_attention_dim)``; all leading input dimensions are folded
            into the first axis by the reshape.
        """
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(*token_shape)
        return self.norm(tokens)
class MLPProjModel(torch.nn.Module):
    """Two-layer MLP projection (Linear -> GELU -> Linear -> LayerNorm).

    Args:
        cross_attention_dim: Feature size of the output.
        clip_embeddings_dim: Feature size of the input and of the hidden layer.
    """

    # NOTE: the two helpers below keep ``module`` as the name of their first
    # parameter. Called on an instance it receives the instance itself;
    # called through the class it accepts any ``torch.nn.Module``.

    def zero_initialize(module):
        """Set every parameter of ``module`` to zero, in place."""
        for param in module.parameters():
            torch.nn.init.zeros_(param)

    def zero_initialize_last_layer(module):
        """Zero the weight (and bias, if any) of the last ``Linear`` in ``module``.

        "Last" means last in ``module.modules()`` traversal order. Nothing
        happens when ``module`` contains no ``torch.nn.Linear``.
        """
        last_layer = None
        for layer in module.modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer

        if last_layer is None:
            return

        torch.nn.init.zeros_(last_layer.weight)
        # A Linear built with ``bias=False`` has ``bias is None``.
        if last_layer.bias is not None:
            torch.nn.init.zeros_(last_layer.bias)

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
            torch.nn.GELU(),
            torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )
        # NOTE: zero-initialising the last layer is currently disabled:
        # self.zero_initialize_last_layer()

    def forward(self, image_embeds):
        """Apply the MLP to ``image_embeds``.

        Args:
            image_embeds: Tensor whose last dimension is ``clip_embeddings_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``cross_attention_dim``.
        """
        return self.proj(image_embeds)
class V2AMapperMLP(torch.nn.Module):
    """MLP mapper with an expanded hidden layer.

    Layout: Linear -> GELU -> Linear -> LayerNorm, where the hidden layer has
    ``clip_embeddings_dim * mult`` features.

    Args:
        cross_attention_dim: Feature size of the output.
        clip_embeddings_dim: Feature size of the input.
        mult: Expansion factor of the hidden layer.
    """

    def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
        super().__init__()
        hidden_dim = clip_embeddings_dim * mult
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(clip_embeddings_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, cross_attention_dim),
            torch.nn.LayerNorm(cross_attention_dim),
        )

    def forward(self, image_embeds):
        """Apply the mapper to ``image_embeds``.

        Args:
            image_embeds: Tensor whose last dimension is ``clip_embeddings_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``cross_attention_dim``.
        """
        return self.proj(image_embeds)
class TimeProjModel(torch.nn.Module):
    """Fuse per-item embeddings with time features through a small MLP.

    Only the ``"text-only"`` feature type is usable in :meth:`forward`. The
    ``"text-image"`` type builds its layers and null features but has no
    forward implementation.

    Args:
        positive_len: Feature size of the embeddings passed to ``forward``.
        out_dim: Output feature size. When a tuple is given its first element
            is used for the layers; ``self.out_dim`` keeps the value exactly
            as passed.
        feature_type: ``"text-only"`` or ``"text-image"``. Any other value
            creates no layers at all.
        frame_nums: Feature size of the time features (``boxes``) that are
            concatenated to the embeddings; stored as ``position_dim``.
    """

    def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums: int = 64):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim
        self.position_dim = frame_nums
        # Remembered so forward() can reject unsupported configurations clearly.
        self.feature_type = feature_type

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = self._build_mlp(out_dim)
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
        elif feature_type == "text-image":
            self.linears_text = self._build_mlp(out_dim)
            self.linears_image = self._build_mlp(out_dim)
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

    def _build_mlp(self, out_dim):
        """Build the 3-layer SiLU MLP shared by every feature type."""
        return nn.Sequential(
            nn.Linear(self.positive_len + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
    ):
        """Project masked embeddings concatenated with their time features.

        Args:
            boxes: Time features, used as-is and concatenated to the
                embeddings on the last dimension (which must therefore be
                ``frame_nums`` wide).
            masks: One value per item, without the feature axis. After a
                trailing axis is added, 1 keeps an embedding and 0 replaces
                it with the learnable null feature. Bool and numeric masks
                are accepted.
            positive_embeddings: Embeddings whose last dimension is
                ``positive_len``; presumably ``(batch, items, positive_len)``
                - TODO confirm against the caller.

        Returns:
            Output of the MLP, last dimension ``out_dim``.

        Raises:
            NotImplementedError: If ``positive_embeddings`` is ``None`` or the
                module was not built with ``feature_type="text-only"``.
        """
        if positive_embeddings is None or self.feature_type != "text-only":
            raise NotImplementedError(
                "TimeProjModel.forward only supports feature_type='text-only' "
                "with positive_embeddings provided"
            )

        # Cast so that ``1 - masks`` also works for bool masks and the blend
        # stays in the dtype of the embeddings.
        masks = masks.unsqueeze(-1).to(dtype=positive_embeddings.dtype)

        time_embeds = boxes

        # learnable null embedding
        positive_null = self.null_positive_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

        objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
        return objs