Spaces:
Paused
Paused
import math | |
import torch | |
import torch.nn as nn | |
# Output width of every IP-Adapter to_k / to_v projection, listed in
# input -> output -> middle block order. Each patched attention layer
# consumes two consecutive entries (key first, then value).
SD_V12_CHANNELS = (
    [320] * 4 + [640] * 4 + [1280] * 4      # input blocks
    + [1280] * 6 + [640] * 6 + [320] * 6    # output blocks
    + [1280] * 2                            # middle block
)
SD_XL_CHANNELS = (
    [640] * 8 + [1280] * 40                 # input blocks
    + [1280] * 60 + [640] * 12              # output blocks
    + [1280] * 20                           # middle block
)
class ImageProjModel(torch.nn.Module):
    """Project a CLIP image embedding into extra cross-attention context tokens.

    One linear layer expands the embedding to
    ``clip_extra_context_tokens * cross_attention_dim`` features; the result
    is split into that many tokens and layer-normalised over the last axis.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        projected_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens shaped ``(-1, clip_extra_context_tokens, cross_attention_dim)``."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(*token_shape)
        return self.norm(tokens)
# Cross-attention to_k / to_v projections for IP-Adapter
class To_KV(torch.nn.Module):
    """Container for the bias-free key/value projections applied to image tokens.

    ``to_kvs`` holds one ``Linear(cross_attention_dim, channel)`` per entry of
    the channel table selected by ``cross_attention_dim``.
    """

    def __init__(self, cross_attention_dim):
        super().__init__()
        # A context width of 2048 selects the SDXL table; every other value
        # falls back to the SD v1/v2 table.
        if cross_attention_dim == 2048:
            channels = SD_XL_CHANNELS
        else:
            channels = SD_V12_CHANNELS
        projections = [torch.nn.Linear(cross_attention_dim, width, bias=False) for width in channels]
        self.to_kvs = torch.nn.ModuleList(projections)

    def load_state_dict(self, state_dict):
        """Assign the weights of ``state_dict`` to ``to_kvs`` by position.

        NOTE: this overrides ``nn.Module.load_state_dict`` with a different
        contract. Key names are ignored; the i-th value of ``state_dict`` (in
        iteration order, expected to be input -> output -> middle) becomes the
        weight of ``to_kvs[i]``. No shape or count validation is performed.
        """
        for index, weight in enumerate(state_dict.values()):
            self.to_kvs[index].weight.data = weight
def FeedForward(dim, mult=4):
    """Build a LayerNorm -> Linear -> GELU -> Linear block.

    The hidden width is ``int(dim * mult)``; both linear layers are bias-free
    and the block maps ``dim`` features back to ``dim`` features.
    """
    hidden = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*stages)
def reshape_tensor(x, heads):
    """Split the feature axis into attention heads and move heads forward.

    (bs, length, width) -> (bs, heads, length, width // heads)

    The result stays 4-D: heads are NOT folded into the batch axis.
    """
    bs, length, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head)
    per_head = x.view(bs, length, heads, -1)
    # (bs, length, heads, dim_per_head) -> (bs, heads, length, dim_per_head)
    heads_first = per_head.transpose(1, 2)
    return heads_first.reshape(bs, heads, length, -1)
class PerceiverAttention(nn.Module):
    """Multi-head attention in which latents attend to image features and themselves.

    Queries come from the latents; keys and values come from the
    concatenation of the image features and the latents.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: updated latents, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        queries = reshape_tensor(self.to_q(latents), self.heads)
        # Keys/values see the image features followed by the latents.
        packed_kv = self.to_kv(torch.cat((x, latents), dim=-2))
        keys, values = (reshape_tensor(part, self.heads) for part in packed_kv.chunk(2, dim=-1))

        # Scale q and k by dim_head ** -0.25 each instead of dividing the
        # product by sqrt(dim_head): more stable with f16.
        root_scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        scores = (queries * root_scale) @ (keys * root_scale).transpose(-2, -1)
        # Softmax in float32, then cast back to the working dtype.
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        attended = (probs @ values).permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(attended)
class Resampler(nn.Module):
    """Resample a sequence of image features into a fixed number of tokens.

    A learned set of ``num_queries`` latents is refined by ``depth`` rounds of
    PerceiverAttention + FeedForward (both with residual connections), then
    projected to ``output_dim`` and layer-normalised.
    """

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        # Learned queries, shape (1, num_queries, dim), scaled by 1/sqrt(dim).
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        # Each entry is an [attention, feed-forward] pair.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x):
        """Return tokens of shape ``(x.size(0), num_queries, output_dim)``."""
        batch = x.size(0)
        latents = self.latents.repeat(batch, 1, 1)
        x = self.proj_in(x)

        for attention, feed_forward in self.layers:
            latents = attention(x, latents) + latents
            latents = feed_forward(latents) + latents

        return self.norm_out(self.proj_out(latents))
class IPAdapterModel(torch.nn.Module):
    """IP-Adapter weights: an image projection model plus key/value projections.

    Args:
        state_dict: checkpoint with an ``"image_proj"`` and an ``"ip_adapter"``
            sub-dictionary.
        clip_embeddings_dim: feature width of the CLIP vision output fed to
            the projection model.
        is_plus: build a ``Resampler`` (16 tokens) instead of an
            ``ImageProjModel``.
    """

    def __init__(self, state_dict, clip_embeddings_dim, is_plus):
        super().__init__()
        self.device = "cpu"
        self.is_plus = is_plus

        # cross_attention_dim equals the text-encoder output width; it is read
        # from the input width of the first to_k projection.
        self.cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]

        if not self.is_plus:
            # The linear projection emits tokens * cross_attention_dim
            # features, so the token count follows from its weight shape.
            proj_rows = state_dict["image_proj"]["proj.weight"].shape[0]
            self.clip_extra_context_tokens = proj_rows // self.cross_attention_dim
            self.image_proj_model = ImageProjModel(
                cross_attention_dim=self.cross_attention_dim,
                clip_embeddings_dim=clip_embeddings_dim,
                clip_extra_context_tokens=self.clip_extra_context_tokens,
            )
        else:
            self.clip_extra_context_tokens = 16
            self.image_proj_model = Resampler(
                dim=self.cross_attention_dim,
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=self.clip_extra_context_tokens,
                embedding_dim=clip_embeddings_dim,
                output_dim=self.cross_attention_dim,
                ff_mult=4,
            )

        self.load_ip_adapter(state_dict)

    def load_ip_adapter(self, state_dict):
        """Load the projection weights and build ``ip_layers`` from ``state_dict``."""
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self.ip_layers = To_KV(self.cross_attention_dim)
        self.ip_layers.load_state_dict(state_dict["ip_adapter"])

    def get_image_embeds(self, clip_vision_output):
        """Return ``(conditional, unconditional)`` image prompt tokens.

        The projection model is moved to the CPU and fed float32 inputs.
        The non-plus path reads ``clip_vision_output['image_embeds']`` and
        uses an all-zero embedding for the unconditional branch; the plus path
        reads the second-to-last entry of ``clip_vision_output['hidden_states']``
        and uses ``clip_vision_h_uc`` for the unconditional branch.
        """
        self.image_proj_model.cpu()

        if not self.is_plus:
            embeds = clip_vision_output['image_embeds'].to(device='cpu', dtype=torch.float32)
            cond_tokens = self.image_proj_model(embeds)
            # input zero vector for unconditional.
            uncond_tokens = self.image_proj_model(torch.zeros_like(embeds))
            return cond_tokens, uncond_tokens

        from annotator.clipvision import clip_vision_h_uc

        hidden = clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32)
        cond = self.image_proj_model(hidden)
        # ``.to(cond)`` matches the device and dtype of the conditional tokens.
        uncond = self.image_proj_model(clip_vision_h_uc.to(cond))
        return cond, uncond
def get_block(model, flag):
    """Return the UNet block sequence named by ``flag``.

    ``flag`` is ``'input'``, ``'middle'`` or ``'output'``. The single middle
    block is wrapped in a one-element list so all three results can be
    indexed the same way. Any other flag raises ``KeyError``.
    """
    blocks_by_flag = dict(
        input=model.input_blocks,
        middle=[model.middle_block],
        output=model.output_blocks,
    )
    return blocks_by_flag[flag]
def attn_forward_hacked(self, x, context=None, **kwargs):
    """Replacement cross-attention ``forward`` that adds IP-Adapter contributions.

    Runs the block's regular scaled-dot-product attention, then adds the
    result of every callable in ``self.ipadapter_hacks`` (each called as
    ``f(self, x, q)`` with the head-split query) before the output projection.

    Args:
        x: query input; the first axis is the batch.
        context: key/value input; defaults to ``x`` (self-attention).
        **kwargs: accepted for signature compatibility and ignored.

    Returns:
        ``self.to_out`` applied to the combined attention output.
    """
    batch_size = x.shape[0]
    h = self.heads

    if context is None:
        context = x

    q = self.to_q(x)
    k = self.to_k(context)
    v = self.to_v(context)
    del context

    # Per-head width comes from the projected query, not from the input
    # width: the two only coincide when to_q keeps the feature size.
    head_dim = q.shape[-1] // h

    def split_heads(t):
        # (batch, tokens, h * head_dim) -> (batch, h, tokens, head_dim)
        return t.view(batch_size, -1, h, head_dim).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    del k, v

    for f in self.ipadapter_hacks:
        out = out + f(self, x, q)
    del q, x

    return self.to_out(out)
# all_hacks maps every patched attention module to its original ``forward``
# so that clear_all_ip_adapter() can restore it; current_model is the UNet
# most recently passed to PlugableIPAdapter.hook().
all_hacks, current_model = {}, None
def hack_blk(block, function, type):
    """Register ``function`` as an IP-Adapter hook on attention module ``block``.

    On the first hook for a block, its current ``forward`` is saved in
    ``all_hacks`` and replaced by ``attn_forward_hacked`` bound to the block
    (``type`` is the owner class handed to the descriptor binding).
    Later hooks are only appended to ``block.ipadapter_hacks``.
    """
    try:
        hooks = block.ipadapter_hacks
    except AttributeError:
        hooks = block.ipadapter_hacks = []

    if not hooks:
        # First hook on this block: remember the original forward and patch.
        all_hacks[block] = block.forward
        block.forward = attn_forward_hacked.__get__(block, type)

    hooks.append(function)
def set_model_attn2_replace(model, function, flag, id):
    """Hook ``function`` onto the ``attn2`` of the first transformer block.

    The target is element 1 of block ``id`` in the ``flag`` section of
    ``model`` (see ``get_block``); binding uses the ``ldm`` ``CrossAttention``.
    """
    from ldm.modules.attention import CrossAttention

    stage = get_block(model, flag)[id]
    attn2 = stage[1].transformer_blocks[0].attn2
    hack_blk(attn2, function, CrossAttention)
def set_model_patch_replace(model, function, flag, id, trans_id):
    """Hook ``function`` onto the ``attn2`` of transformer block ``trans_id``.

    The target is element 1 of block ``id`` in the ``flag`` section of
    ``model`` (see ``get_block``); binding uses the ``sgm`` ``CrossAttention``.
    """
    from sgm.modules.attention import CrossAttention

    stage = get_block(model, flag)[id]
    attn2 = stage[1].transformer_blocks[trans_id].attn2
    hack_blk(attn2, function, CrossAttention)
def clear_all_ip_adapter():
    """Undo every IP-Adapter patch and reset the module-level bookkeeping.

    Each module recorded in ``all_hacks`` gets its saved ``forward`` back and
    an empty ``ipadapter_hacks`` list; ``all_hacks`` is then rebound to a new
    empty dict and ``current_model`` to ``None``.
    """
    global all_hacks, current_model
    for module, original_forward in all_hacks.items():
        module.forward = original_forward
        module.ipadapter_hacks = []
    all_hacks, current_model = {}, None
class PlugableIPAdapter(torch.nn.Module):
    """Attach an IP-Adapter to a UNet by patching its cross-attention layers.

    ``hook`` computes the image prompt tokens once and registers one extra
    attention term per patched layer; the key/value projections of those
    tokens are memoised in ``cache`` until ``reset`` or the next ``hook``.
    """

    def __init__(self, state_dict, clip_embeddings_dim, is_plus):
        super().__init__()
        self.is_plus = is_plus
        # The SDXL layer layout is used for 1280-wide embeddings of non-plus models.
        self.sdxl = clip_embeddings_dim == 1280 and not is_plus
        self.ipadapter = IPAdapterModel(state_dict, clip_embeddings_dim=clip_embeddings_dim, is_plus=is_plus)
        self.disable_memory_management = True
        self.dtype = None
        self.weight = 1.0
        self.cache = {}
        self.p_start = 0.0
        self.p_end = 1.0

    def reset(self):
        """Drop all memoised key/value projections."""
        self.cache = {}

    def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float32):
        """Patch ``model`` so its cross-attention also attends to the image tokens.

        Args:
            model: UNet to patch; also stored in the module-level ``current_model``.
            clip_vision_output: CLIP vision output passed to ``get_image_embeds``.
            weight: multiplier applied to the IP-Adapter attention output.
            start, end: sampling-percent window in which the adapter is active.
            dtype: dtype for the adapter weights and image tokens.
        """
        global current_model
        current_model = model

        self.p_start, self.p_end = start, end
        self.cache = {}
        self.weight = weight
        self.dtype = dtype

        cpu = torch.device('cpu')
        self.ipadapter.to(cpu, dtype=self.dtype)
        cond, uncond = self.ipadapter.get_image_embeds(clip_vision_output)
        self.image_emb = cond.to(cpu, dtype=self.dtype)
        self.uncond_image_emb = uncond.to(cpu, dtype=self.dtype)

        # From https://github.com/laksjdjf/IPAdapter-ComfyUI
        # The position of a target in the list is its layer number, i.e. the
        # index of its key/value pair in to_kvs (input -> output -> middle).
        if not self.sdxl:
            targets = [("input", block_id) for block_id in (1, 2, 4, 5, 7, 8)]
            targets += [("output", block_id) for block_id in range(3, 12)]
            targets.append(("middle", 0))
            for number, (flag, block_id) in enumerate(targets):
                set_model_attn2_replace(model, self.patch_forward(number), flag, block_id)
            return

        targets = []
        for block_id in (4, 5, 7, 8):  # input_blocks that have cross attention
            depth = 2 if block_id in (4, 5) else 10  # transformer_depth
            targets += [("input", block_id, trans_id) for trans_id in range(depth)]
        for block_id in range(6):  # output_blocks that have cross attention
            depth = 2 if block_id in (3, 4, 5) else 10  # transformer_depth
            targets += [("output", block_id, trans_id) for trans_id in range(depth)]
        targets += [("middle", 0, trans_id) for trans_id in range(10)]
        for number, (flag, block_id, trans_id) in enumerate(targets):
            set_model_patch_replace(model, self.patch_forward(number), flag, block_id, trans_id)

    def call_ip(self, number, feat, device):
        """Return ``to_kvs[number](feat)`` on ``device``, memoised by ``number``.

        The cache key is ``number`` only, so a cached entry is returned even
        if ``feat`` differs from the tensor it was computed from.
        """
        if number not in self.cache:
            projection = self.ipadapter.ip_layers.to_kvs[number]
            self.cache[number] = projection(feat).to(device)
        return self.cache[number]

    def patch_forward(self, number):
        """Build the hook for layer ``number`` (uses to_kvs ``2n`` and ``2n + 1``)."""

        def forward(attn_blk, x, q):
            batch_size, _, inner_dim = x.shape
            h = attn_blk.heads
            head_dim = inner_dim // h

            # Outside the [p_start, p_end] window the adapter contributes nothing.
            current_sampling_percent = getattr(current_model, 'current_sampling_percent', 0.5)
            if current_sampling_percent < self.p_start or current_sampling_percent > self.p_end:
                return 0

            # Blend conditional and unconditional tokens per batch entry
            # (presumably cond_mark is 1 for conditional entries - confirm
            # against the code that sets current_model.cond_mark).
            cond_mark = current_model.cond_mark[:, :, :, 0].to(self.image_emb)
            cond_uncond_image_emb = self.image_emb * cond_mark + self.uncond_image_emb * (1 - cond_mark)

            ip_k = self.call_ip(number * 2, cond_uncond_image_emb, device=q.device)
            ip_v = self.call_ip(number * 2 + 1, cond_uncond_image_emb, device=q.device)

            # (batch, tokens, h * head_dim) -> (batch, h, tokens, head_dim)
            ip_k = ip_k.view(batch_size, -1, h, head_dim).transpose(1, 2)
            ip_v = ip_v.view(batch_size, -1, h, head_dim).transpose(1, 2)

            ip_out = torch.nn.functional.scaled_dot_product_attention(
                q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
            return ip_out * self.weight

        return forward