# QuintW's picture
# Upload 1350 files
# 3f9c56c
# raw
# history blame
# 13.6 kB
import math
import torch
import torch.nn as nn
# Output width of every IP-Adapter to_k / to_v projection, in the order the
# cross-attention layers are patched: input blocks, then output blocks, then
# the middle block. Each cross-attention layer contributes two consecutive
# entries: its to_k width followed by its to_v width.
SD_V12_CHANNELS = (
    [320] * 4 + [640] * 4 + [1280] * 4  # input blocks
    + [1280] * 6 + [640] * 6 + [320] * 6  # output blocks
    + [1280] * 2  # middle block
)
SD_XL_CHANNELS = (
    [640] * 8 + [1280] * 40  # input blocks
    + [1280] * 60 + [640] * 12  # output blocks
    + [1280] * 20  # middle block
)
class ImageProjModel(torch.nn.Module):
    """Project a pooled CLIP image embedding into a short sequence of context tokens.

    One linear layer expands the embedding to
    ``clip_extra_context_tokens * cross_attention_dim`` features. These are
    split into ``clip_extra_context_tokens`` tokens, and each token is
    layer-normalised.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        total_features = clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, total_features)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens of shape (-1, clip_extra_context_tokens, cross_attention_dim)."""
        flat = self.proj(image_embeds)
        tokens = flat.reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(tokens)
# Cross-attention to_k / to_v projections for IPAdapter.
class To_KV(torch.nn.Module):
    """Bias-free linear layers that map image-prompt tokens to attention keys/values.

    One ``Linear(cross_attention_dim, width)`` is created per entry of a
    channel table. A ``cross_attention_dim`` of 2048 selects
    ``SD_XL_CHANNELS``; any other value selects ``SD_V12_CHANNELS``.
    """

    def __init__(self, cross_attention_dim):
        super().__init__()
        if cross_attention_dim == 2048:
            channel_table = SD_XL_CHANNELS
        else:
            channel_table = SD_V12_CHANNELS
        layers = [torch.nn.Linear(cross_attention_dim, width, bias=False) for width in channel_table]
        self.to_kvs = torch.nn.ModuleList(layers)

    def load_state_dict(self, state_dict):
        """Copy weights positionally: the n-th tensor of ``state_dict`` goes to ``to_kvs[n]``.

        NOTE: this overrides ``torch.nn.Module.load_state_dict`` with a
        different signature (no ``strict``). Keys are consumed in the
        mapping's iteration order (expected: input -> output -> middle), and
        their names are not inspected. Supplying more tensors than layers
        raises IndexError. Supplying fewer leaves the remaining layers unchanged.
        """
        for position, name in enumerate(state_dict):
            self.to_kvs[position].weight.data = state_dict[name]
def FeedForward(dim, mult=4):
    """Build a pre-norm MLP block: LayerNorm -> Linear -> GELU -> Linear.

    The hidden width is ``int(dim * mult)``. Neither Linear layer has a bias.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
def reshape_tensor(x, heads):
    """Split the feature axis of ``x`` into attention heads.

    Args:
        x (torch.Tensor): tensor of shape (bs, length, width). ``width`` must
            be divisible by ``heads``, otherwise ``view`` raises.
        heads (int): number of attention heads.

    Returns:
        torch.Tensor: shape (bs, heads, length, width // heads). The result
        stays 4-D and is NOT flattened to (bs * heads, ...).
        PerceiverAttention.forward relies on this when it undoes the split
        with ``permute(0, 2, 1, 3)``.
    """
    bs, length, _ = x.shape
    # (bs, length, width) --> (bs, length, heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, heads, dim_per_head) --> (bs, heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # The shape is already (bs, heads, length, dim_per_head); this reshape
    # leaves it unchanged (4-D) rather than merging bs and heads.
    x = x.reshape(bs, heads, length, -1)
    return x
class PerceiverAttention(nn.Module):
    """Multi-head attention in which latents attend to image features.

    Queries come from the (normalised) latents. Keys and values come from the
    (normalised) image features concatenated with the latents along the
    sequence axis.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # Stored for reference only; forward computes its own split scale.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features, shape (b, n1, D).
            latents (torch.Tensor): latent features, shape (b, n2, D).

        Returns:
            torch.Tensor: attention output of shape (b, n2, D).
        """
        features = self.norm1(x)
        query_source = self.norm2(latents)
        batch, n_latents, _ = query_source.shape

        queries = self.to_q(query_source)
        context = torch.cat((features, query_source), dim=-2)
        keys, values = self.to_kv(context).chunk(2, dim=-1)

        queries, keys, values = (reshape_tensor(t, self.heads) for t in (queries, keys, values))

        # Split the 1/sqrt(dim_head) factor evenly between q and k. Scaling
        # before the matmul is more stable in fp16 than dividing afterwards.
        half_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (queries * half_scale) @ (keys * half_scale).transpose(-2, -1)
        # Softmax in float32, then cast back to the working dtype.
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        attended = probs @ values
        attended = attended.permute(0, 2, 1, 3).reshape(batch, n_latents, -1)
        return self.to_out(attended)
class Resampler(nn.Module):
    """Perceiver-style resampler that turns input features into ``num_queries`` tokens.

    A learned latent array is refined by ``depth`` rounds of residual
    PerceiverAttention followed by residual FeedForward. The result is then
    projected to ``output_dim`` and layer-normalised.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        # Parameters are created in a fixed order so random initialisation is
        # reproducible under a given seed.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Resample batch-first ``x`` (last dim ``embedding_dim``) into output tokens."""
        batch_size = x.size(0)
        tokens = self.latents.repeat(batch_size, 1, 1)
        features = self.proj_in(x)
        for attention, feed_forward in self.layers:
            tokens = tokens + attention(features, tokens)
            tokens = tokens + feed_forward(tokens)
        return self.norm_out(self.proj_out(tokens))
class IPAdapterModel(torch.nn.Module):
    """IP-Adapter weights: an image-projection model plus per-layer to_k/to_v projections.

    Args:
        state_dict (dict): checkpoint with an ``"image_proj"`` entry (weights
            for the projection model) and an ``"ip_adapter"`` entry (to_k/to_v
            weights, copied positionally into ``To_KV``).
        clip_embeddings_dim (int): feature width of the CLIP vision input fed
            to the projection model.
        is_plus (bool): use the ``Resampler`` projection (fed CLIP hidden
            states) instead of ``ImageProjModel`` (fed pooled image embeds).
    """

    def __init__(self, state_dict, clip_embeddings_dim, is_plus):
        super().__init__()
        self.device = "cpu"
        # cross_attention_dim equals the text-encoder output width. It is the
        # input width (shape[1]) of the IP to_k weight.
        self.cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
        self.is_plus = is_plus
        if self.is_plus:
            # Resampler keeps its learned queries in ``latents`` with shape
            # (1, num_queries, dim). Reading the count from the checkpoint
            # lets plus checkpoints with a non-default token count load,
            # instead of failing with a size mismatch in load_state_dict.
            # 16 remains the fallback.
            latents = state_dict["image_proj"].get("latents")
            self.clip_extra_context_tokens = latents.shape[1] if latents is not None else 16
            # NOTE(review): depth/heads are still hard-coded for the plus model.
            self.image_proj_model = Resampler(
                dim=self.cross_attention_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=self.clip_extra_context_tokens,
                embedding_dim=clip_embeddings_dim,
                output_dim=self.cross_attention_dim,
                ff_mult=4,
            )
        else:
            # proj.weight has shape (tokens * cross_attention_dim, clip_embeddings_dim).
            self.clip_extra_context_tokens = (
                state_dict["image_proj"]["proj.weight"].shape[0] // self.cross_attention_dim
            )
            self.image_proj_model = ImageProjModel(
                cross_attention_dim=self.cross_attention_dim,
                clip_embeddings_dim=clip_embeddings_dim,
                clip_extra_context_tokens=self.clip_extra_context_tokens,
            )
        self.load_ip_adapter(state_dict)

    def load_ip_adapter(self, state_dict):
        """Load the projection weights, then build and load the to_k/to_v layers."""
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.ip_layers = To_KV(self.cross_attention_dim)
        self.ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, clip_vision_output):
        """Return ``(cond, uncond)`` image-prompt embeddings.

        The projection model is moved to the CPU and its input is cast to
        float32 on the CPU. Plus models read
        ``clip_vision_output['hidden_states'][-2]`` and project the
        precomputed ``clip_vision_h_uc`` for the unconditional branch. Other
        models read ``clip_vision_output['image_embeds']`` and project an
        all-zeros tensor for the unconditional branch.
        """
        self.image_proj_model.cpu()
        if self.is_plus:
            from annotator.clipvision import clip_vision_h_uc
            cond = self.image_proj_model(clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32))
            uncond = self.image_proj_model(clip_vision_h_uc.to(cond))
            return cond, uncond
        clip_image_embeds = clip_vision_output['image_embeds'].to(device='cpu', dtype=torch.float32)
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        # input zero vector for unconditional.
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds
def get_block(model, flag):
    """Return the sequence of UNet blocks selected by ``flag``.

    ``flag`` is ``'input'``, ``'middle'`` or ``'output'``. The middle block is
    wrapped in a one-element list so all three can be indexed the same way.
    All three attributes are read before the lookup. An unknown flag raises
    KeyError.
    """
    blocks_by_flag = dict(
        input=model.input_blocks,
        middle=[model.middle_block],
        output=model.output_blocks,
    )
    return blocks_by_flag[flag]
def attn_forward_hacked(self, x, context=None, **kwargs):
    """Replacement ``forward`` installed on patched CrossAttention blocks.

    Runs multi-head attention of ``x`` over ``context``, or over ``x`` itself
    when no context is given. Before the output projection it adds the
    result of every function in ``self.ipadapter_hacks``, each called as
    ``f(self, x, q)`` with the head-split queries. Extra keyword arguments
    are accepted and ignored.
    """
    batch_size, _, inner_dim = x.shape
    heads = self.heads
    head_dim = inner_dim // heads
    source = x if context is None else context

    def split_heads(t):
        # (b, n, heads * head_dim) -> (b, heads, n, head_dim)
        return t.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    q = split_heads(self.to_q(x))
    k = split_heads(self.to_k(source))
    v = split_heads(self.to_v(source))
    del source, context

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    del k, v
    out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    for hack in self.ipadapter_hacks:
        out = out + hack(self, x, q)
    return self.to_out(out)
# Registry of patched attention blocks: block -> the ``forward`` it had before
# hack_blk replaced it. clear_all_ip_adapter restores these.
all_hacks = dict()
# Model most recently passed to PlugableIPAdapter.hook. The patched forwards
# read it, and clear_all_ip_adapter resets it to None.
current_model = None
def hack_blk(block, function, type):
    """Register ``function`` as an IP-Adapter hook on the attention ``block``.

    On the first registration (``block.ipadapter_hacks`` missing or empty),
    the block's current ``forward`` is saved in ``all_hacks``. It is then
    replaced by ``attn_forward_hacked`` bound to the block, with ``type``
    used as the owner class in ``__get__``. Later registrations only append
    ``function`` to the hook list.
    """
    if not hasattr(block, 'ipadapter_hacks'):
        block.ipadapter_hacks = []
    first_registration = len(block.ipadapter_hacks) == 0
    if first_registration:
        all_hacks[block] = block.forward
        block.forward = attn_forward_hacked.__get__(block, type)
    block.ipadapter_hacks.append(function)
def set_model_attn2_replace(model, function, flag, id):
    """Hook ``function`` into the cross-attention (``attn2``) of an ldm UNet block.

    The target is ``get_block(model, flag)[id][1].transformer_blocks[0].attn2``,
    i.e. the first transformer block only.
    """
    from ldm.modules.attention import CrossAttention
    unet_blocks = get_block(model, flag)
    target = unet_blocks[id][1].transformer_blocks[0].attn2
    hack_blk(target, function, CrossAttention)
def set_model_patch_replace(model, function, flag, id, trans_id):
    """Hook ``function`` into ``attn2`` of transformer ``trans_id`` in an sgm UNet block.

    The target is
    ``get_block(model, flag)[id][1].transformer_blocks[trans_id].attn2``.
    """
    from sgm.modules.attention import CrossAttention
    transformer = get_block(model, flag)[id][1].transformer_blocks[trans_id]
    hack_blk(transformer.attn2, function, CrossAttention)
def clear_all_ip_adapter():
    """Undo every IP-Adapter patch and forget the hooked model.

    Each block recorded in ``all_hacks`` gets its saved ``forward`` back and
    an empty ``ipadapter_hacks`` list. Afterwards the registry is emptied and
    ``current_model`` is reset to None.
    """
    global all_hacks, current_model
    for patched, saved_forward in all_hacks.items():
        patched.forward = saved_forward
        patched.ipadapter_hacks = []
    all_hacks, current_model = {}, None
class PlugableIPAdapter(torch.nn.Module):
    """Attach an IP-Adapter to a UNet by patching its cross-attention blocks.

    ``hook`` computes the image-prompt embeddings once and registers one
    closure (built by ``patch_forward``) per cross-attention layer. While the
    sampling progress is inside [p_start, p_end], each closure adds weighted
    attention over the IP-Adapter keys/values to that layer's output.
    """

    def __init__(self, state_dict, clip_embeddings_dim, is_plus):
        super().__init__()
        # NOTE(review): SDXL is inferred only from a 1280-dim CLIP embedding on
        # a non-plus model. Plus models always take the non-SDXL branch in hook.
        self.sdxl = clip_embeddings_dim == 1280 and not is_plus
        self.is_plus = is_plus
        self.ipadapter = IPAdapterModel(state_dict, clip_embeddings_dim=clip_embeddings_dim, is_plus=is_plus)
        self.disable_memory_management = True
        self.dtype = None
        self.weight = 1.0
        self.cache = {}
        self.p_start = 0.0
        self.p_end = 1.0

    def reset(self):
        """Drop all cached to_k/to_v projections."""
        self.cache = {}

    @torch.no_grad()
    def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float32):
        """Compute the image embeddings and patch every cross-attention layer of ``model``.

        Args:
            model: UNet to patch. It is also stored as the module-level
                ``current_model``, which the patched forwards read.
            clip_vision_output: CLIP vision outputs passed to
                ``IPAdapterModel.get_image_embeds``.
            weight (float): multiplier applied to the IP-Adapter attention output.
            start, end (float): sampling-progress window in which the patch is active.
            dtype: dtype for the adapter weights and embeddings, which stay on the CPU.
        """
        global current_model
        current_model = model
        self.p_start, self.p_end = start, end
        self.cache = {}
        self.weight = weight
        self.dtype = dtype
        cpu = torch.device('cpu')
        self.ipadapter.to(cpu, dtype=self.dtype)
        cond, uncond = self.ipadapter.get_image_embeds(clip_vision_output)
        self.image_emb = cond.to(cpu, dtype=self.dtype)
        self.uncond_image_emb = uncond.to(cpu, dtype=self.dtype)

        # Layer order follows https://github.com/laksjdjf/IPAdapter-ComfyUI.
        # The enumeration index n selects the to_kvs pair (2n, 2n + 1).
        if not self.sdxl:
            # (flag, block id) for every block with cross attention.
            targets = (
                [("input", block_id) for block_id in (1, 2, 4, 5, 7, 8)]
                + [("output", block_id) for block_id in range(3, 12)]
                + [("middle", 0)]
            )
            for number, (flag, block_id) in enumerate(targets):
                set_model_attn2_replace(model, self.patch_forward(number), flag, block_id)
        else:
            # (flag, block id, transformer_depth) for every block with cross attention.
            layout = (
                [("input", 4, 2), ("input", 5, 2), ("input", 7, 10), ("input", 8, 10)]
                + [("output", block_id, 10) for block_id in (0, 1, 2)]
                + [("output", block_id, 2) for block_id in (3, 4, 5)]
                + [("middle", 0, 10)]
            )
            targets = [
                (flag, block_id, transformer_index)
                for flag, block_id, depth in layout
                for transformer_index in range(depth)
            ]
            for number, (flag, block_id, transformer_index) in enumerate(targets):
                set_model_patch_replace(model, self.patch_forward(number), flag, block_id, transformer_index)

    def call_ip(self, number, feat, device):
        """Return ``to_kvs[number](feat)`` moved to ``device``, memoised in ``self.cache``.

        NOTE(review): the cache key is ``number`` only and ``feat`` is not part
        of it. A later call with different features (e.g. after ``cond_mark``
        changes) gets the first result until ``hook`` or ``reset`` clears the cache.
        """
        try:
            return self.cache[number]
        except KeyError:
            pass
        projected = self.ipadapter.ip_layers.to_kvs[number](feat).to(device)
        self.cache[number] = projected
        return projected

    @torch.no_grad()
    def patch_forward(self, number):
        """Build the hook for the ``number``-th patched cross-attention layer.

        The returned ``forward(attn_blk, x, q)`` returns 0 when the sampling
        progress is outside [p_start, p_end]. Otherwise it:

        - blends the cond/uncond image embeddings with ``current_model.cond_mark``;
        - projects them through to_kvs[2 * number] (keys) and
          to_kvs[2 * number + 1] (values);
        - attends ``q`` over the projected keys/values;
        - returns the result scaled by ``self.weight``.
        """

        @torch.no_grad()
        def forward(attn_blk, x, q):
            batch_size, _, inner_dim = x.shape
            heads = attn_blk.heads
            head_dim = inner_dim // heads

            progress = getattr(current_model, 'current_sampling_percent', 0.5)
            if progress < self.p_start or progress > self.p_end:
                return 0

            # Linear blend: mask 1 -> image_emb, mask 0 -> uncond_image_emb
            # (cond_mark is presumably a per-sample 0/1 mask; confirm upstream).
            mask = current_model.cond_mark[:, :, :, 0].to(self.image_emb)
            image_tokens = self.image_emb * mask + self.uncond_image_emb * (1 - mask)

            ip_k = self.call_ip(number * 2, image_tokens, device=q.device)
            ip_v = self.call_ip(number * 2 + 1, image_tokens, device=q.device)
            ip_k, ip_v = (
                t.view(batch_size, -1, heads, head_dim).transpose(1, 2) for t in (ip_k, ip_v)
            )

            ip_out = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False)
            ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
            return ip_out * self.weight

        return forward