import torch
import torch.nn as nn
import numpy as np
import os
from typing import List
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from foleycrafter.models.adapters.resampler import Resampler
from foleycrafter.models.adapters.utils import is_torch2_available
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
class VideoProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=1, video_frame=50):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
self.video_frame = video_frame
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
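# Usage sketch for VideoProjModel (tensor shapes are assumptions): per-frame CLIP
# embeddings are projected one-to-one into cross-attention features, keeping one
# token per frame (no reshape, unlike ImageProjModel below).
#
#   video_proj = VideoProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024, video_frame=50)
#   frame_embeds = torch.randn(2, 50, 1024)   # (batch, video_frame, clip_embeddings_dim)
#   tokens = video_proj(frame_embeds)         # -> (2, 50, 1024)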
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
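# Usage sketch for ImageProjModel (tensor shapes are assumptions): one pooled CLIP
# image embedding is expanded into clip_extra_context_tokens context tokens.
#
#   image_proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
#   image_embeds = torch.randn(2, 1024)   # (batch, clip_embeddings_dim)
#   tokens = image_proj(image_embeds)     # -> (2, 4, 1024) after the linear projection and reshape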
class MLPProjModel(torch.nn.Module):
    """MLP projection model for image-prompt embeddings used with Stable Diffusion."""
    def zero_initialize(self):
        # Zero out every parameter of this module.
        for param in self.parameters():
            param.data.zero_()
    def zero_initialize_last_layer(self):
        # Zero out only the weights and bias of the last Linear layer.
        last_layer = None
        for _, layer in self.named_modules():
            if isinstance(layer, torch.nn.Linear):
                last_layer = layer
        if last_layer is not None:
            last_layer.weight.data.zero_()
            last_layer.bias.data.zero_()
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim)
)
        # Optionally zero-initialize the last layer (disabled by default):
        # self.zero_initialize_last_layer()
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
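# Usage sketch for MLPProjModel (tensor shapes are assumptions): a two-layer MLP
# with GELU and LayerNorm maps CLIP embeddings into the cross-attention dimension
# without adding extra tokens.
#
#   mlp_proj = MLPProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024)
#   image_embeds = torch.randn(2, 1024)   # (batch, clip_embeddings_dim)
#   tokens = mlp_proj(image_embeds)       # -> (2, 1024)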
class V2AMapperMLP(torch.nn.Module):
def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult),
torch.nn.GELU(),
torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim),
torch.nn.LayerNorm(cross_attention_dim)
)
def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
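# Usage sketch for V2AMapperMLP (the 512-dim embedding sizes are assumptions): an
# MLP with a widened hidden layer (mult=4) maps visual embeddings toward the
# audio-generation conditioning space.
#
#   mapper = V2AMapperMLP(cross_attention_dim=512, clip_embeddings_dim=512, mult=4)
#   visual_embeds = torch.randn(2, 512)   # (batch, clip_embeddings_dim)
#   audio_cond = mapper(visual_embeds)    # -> (2, 512)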
class TimeProjModel(torch.nn.Module):
    def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums: int = 64):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim
self.position_dim = frame_nums
if isinstance(out_dim, tuple):
out_dim = out_dim[0]
if feature_type == "text-only":
self.linears = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
elif feature_type == "text-image":
self.linears_text = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.linears_image = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)
self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
# self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
def forward(
self,
boxes,
masks,
positive_embeddings=None,
):
masks = masks.unsqueeze(-1)
        # # embedding position (it may include padding as a placeholder)
# xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
# # learnable null embedding
# xyxy_null = self.null_position_feature.view(1, 1, -1)
# # replace padding with learnable null embedding
# xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
time_embeds = boxes
        # PositionNet-style projection with text-only information
if positive_embeddings is not None:
# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
        # PositionNet-style projection with text and image information
else:
raise NotImplementedError
return objs
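# Usage sketch for TimeProjModel in "text-only" mode (names and shapes are
# assumptions): per-entry time/onset encodings of length frame_nums are
# concatenated with masked text features and projected to out_dim; masked
# entries fall back to the learnable null embedding.
#
#   time_proj = TimeProjModel(positive_len=768, out_dim=768, frame_nums=64)
#   time_embeds = torch.randn(2, 4, 64)    # (batch, N, frame_nums), passed as `boxes`
#   masks = torch.ones(2, 4)               # 1 = real entry, 0 = padding
#   text_embeds = torch.randn(2, 4, 768)   # (batch, N, positive_len)
#   objs = time_proj(time_embeds, masks, positive_embeddings=text_embeds)  # -> (2, 4, 768)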