# Source: ImageCritic / src/layers.py
# Hosting-page residue: "ziheng1234's picture", "Upload 39 files", "3e8fe6c verified"
import inspect
import math
from typing import Callable, List, Optional, Tuple, Union
from einops import rearrange
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from diffusers.models.attention_processor import Attention
# Global variables for attention visualization
# NOTE(review): `step` is not referenced by any definition visible in this file.
step = 0
# Call counter incremented once per call by MultiDoubleStreamBlockLoraProcessor_visual.
global_timestep = 0
# Call counter incremented once per call by MultiSingleStreamBlockLoraProcessor_visual.
global_timestep2 = 0
def scaled_dot_product_average_attention_map(query, key, attn_mask=None, is_causal=False, scale=None) -> torch.Tensor:
    """Return scaled query/key scores averaged over dims 1 and 2.

    Computes ``query @ key^T * scale``, adds an additive bias built from
    ``is_causal`` / ``attn_mask``, and takes the mean over dims 1 and 2. For
    ``(batch, heads, seq, head_dim)`` inputs that is the mean over heads and
    query positions, giving a ``(batch, key_len)`` tensor.

    No softmax is applied: the result is an average of (biased) logits, not of
    probabilities, and any key position that is masked for at least one
    averaged entry comes out as ``-inf``.

    Args:
        query: Tensor whose last two dims are ``(L, head_dim)``.
        key: Tensor whose last two dims are ``(S, head_dim)``.
        attn_mask: Optional mask broadcastable to the score tensor. A boolean
            mask keeps positions that are ``True``; any other dtype is added
            to the scores.
        is_causal: Apply a lower-triangular mask; requires ``attn_mask`` to be
            ``None``.
        scale: Score multiplier; defaults to ``1 / sqrt(head_dim)``.

    Returns:
        The biased scores averaged over dims 1 and 2.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Build the bias on the query's device so no cross-device arithmetic is needed.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        attn_mask = attn_mask.to(query.device)
        if attn_mask.dtype == torch.bool:
            # Turn the boolean mask into an additive bias on a fresh tensor; the
            # caller's mask must not be modified.
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            # Out-of-place add so masks with leading batch/head dims broadcast.
            attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    return attn_weight.mean(dim=(1, 2))
class LoRALinearLayer(nn.Module):
    """Low-rank adapter made of two bias-free linear maps.

    ``down`` maps ``in_features -> rank`` and ``up`` maps
    ``rank -> out_features``. ``up`` starts at zero, so a freshly built layer
    outputs zeros. When ``network_alpha`` is given, the output is multiplied by
    ``network_alpha / rank``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.down = nn.Linear(in_features, rank, bias=False, **factory_kwargs)
        self.up = nn.Linear(rank, out_features, bias=False, **factory_kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        self.network_alpha = network_alpha

        # Random down-projection, zero up-projection.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the adapter in the weights' dtype and return in the input's dtype."""
        input_dtype = hidden_states.dtype
        compute_dtype = self.down.weight.dtype

        projected = self.up(self.down(hidden_states.to(compute_dtype)))
        if self.network_alpha is None:
            return projected.to(input_dtype)

        projected *= self.network_alpha / self.rank
        return projected.to(input_dtype)
class MultiSingleStreamBlockLoraProcessor(nn.Module):
    """Attention processor that adds weighted LoRA deltas to q/k/v.

    Holds ``n_loras`` independent ``LoRALinearLayer`` adapters for each of the
    query, key and value projections. On call it projects ``hidden_states``
    with ``attn.to_q/to_k/to_v``, adds ``lora_weights[i] * lora_i(hidden_states)``
    for every adapter, runs scaled dot-product attention and returns the
    merged-head result. No output projection is applied here.
    """

    def __init__(self, in_features: int, out_features: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, n_loras=1):
        """Build the q/k/v adapters.

        Args:
            in_features: Input width of every adapter.
            out_features: Output width of every adapter.
            ranks: Per-adapter rank; needs at least ``n_loras`` entries.
            lora_weights: Per-adapter scalar multiplier; stored by reference.
            network_alphas: Per-adapter alpha; needs at least ``n_loras`` entries.
            device: Device passed to every adapter.
            dtype: Dtype passed to every adapter.
            n_loras: Number of adapters per projection.
        """
        super().__init__()
        # `None` sentinels replace the former mutable `[]` defaults so that no
        # list object is shared between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        # Initialize a list to store the LoRA layers
        self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.lora_weights = [] if lora_weights is None else lora_weights

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond = False,
    ) -> torch.FloatTensor:
        """Run LoRA-augmented self-attention over ``hidden_states``.

        ``encoder_hidden_states`` is only consulted for the batch size when it
        is provided; ``use_cond`` is accepted but not used.

        Returns:
            Tensor of shape ``(batch, seq, attn.heads * head_dim)``.
        """
        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Each adapter contributes an additive, individually weighted delta.
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        return hidden_states
class MultiDoubleStreamBlockLoraProcessor(nn.Module):
    """Joint text/image attention processor with weighted LoRA deltas.

    Holds ``n_loras`` ``LoRALinearLayer`` adapters for each of the q, k, v and
    output projections of the ``hidden_states`` stream. The
    ``encoder_hidden_states`` stream goes through ``attn.add_*_proj`` without
    adapters. Both streams are concatenated along the sequence axis (encoder
    tokens first), attended jointly, and split again.
    """

    def __init__(self, in_features: int, out_features: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, n_loras=1):
        """Build the q/k/v/proj adapters.

        Args:
            in_features: Input width of every adapter.
            out_features: Output width of every adapter.
            ranks: Per-adapter rank; needs at least ``n_loras`` entries.
            lora_weights: Per-adapter scalar multiplier; stored by reference.
            network_alphas: Per-adapter alpha; needs at least ``n_loras`` entries.
            device: Device passed to every adapter.
            dtype: Dtype passed to every adapter.
            n_loras: Number of adapters per projection.
        """
        super().__init__()
        # `None` sentinels replace the former mutable `[]` defaults so that no
        # list object is shared between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        # Initialize a list to store the LoRA layers
        self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.proj_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.lora_weights = [] if lora_weights is None else lora_weights

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
    ) -> torch.FloatTensor:
        """Run joint attention over ``[encoder_hidden_states, hidden_states]``.

        ``use_cond`` is accepted but not used.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` after their respective
            output projections.
        """
        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Derive the width from the projection itself rather than assuming 3072,
        # so the head split stays correct for other model sizes.
        inner_dim = encoder_hidden_states_query_proj.shape[-1]
        head_dim = inner_dim // attn.heads

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Each adapter contributes an additive, individually weighted delta.
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # attention: encoder tokens come first on the sequence axis
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Undo the concatenation: leading tokens belong to the encoder stream.
        encoder_hidden_states, hidden_states = (
            hidden_states[:, : encoder_hidden_states.shape[1]],
            hidden_states[:, encoder_hidden_states.shape[1] :],
        )

        # Linear projection (with LoRA weight applied to each proj layer)
        # NOTE: the proj adapters read the *projected* tensor and each one sees
        # the running sum including earlier adapters' contributions.
        hidden_states = attn.to_out[0](hidden_states)
        for i in range(self.n_loras):
            hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return (hidden_states, encoder_hidden_states)
class MultiSingleStreamBlockLoraProcessorWithLoss(nn.Module):
    """Single-stream LoRA attention processor (loss variant).

    Holds ``n_loras`` ``LoRALinearLayer`` adapters for each of the q, k and v
    projections and runs the same computation as a plain LoRA-augmented
    self-attention: project, add weighted adapter deltas, attend, merge heads.

    NOTE: capturing the condition-vs-noise attention maps on ``attn`` is
    disabled in this class (that code was commented out), so no extra
    attributes are written to ``attn`` here.
    """

    def __init__(self, in_features: int, out_features: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, n_loras=1):
        """Build the q/k/v adapters.

        Args:
            in_features: Input width of every adapter.
            out_features: Output width of every adapter.
            ranks: Per-adapter rank; needs at least ``n_loras`` entries.
            lora_weights: Per-adapter scalar multiplier; stored by reference.
            network_alphas: Per-adapter alpha; needs at least ``n_loras`` entries.
            device: Device passed to every adapter.
            dtype: Dtype passed to every adapter.
            n_loras: Number of adapters per projection.
        """
        super().__init__()
        # `None` sentinels replace the former mutable `[]` defaults so that no
        # list object is shared between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        # Initialize a list to store the LoRA layers
        self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.lora_weights = [] if lora_weights is None else lora_weights

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond = False,
    ) -> torch.FloatTensor:
        """Run LoRA-augmented self-attention over ``hidden_states``.

        ``encoder_hidden_states`` is only consulted for the batch size when it
        is provided; ``use_cond`` is accepted but not used.

        Returns:
            Tensor of shape ``(batch, seq, attn.heads * head_dim)``.
        """
        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Each adapter contributes an additive, individually weighted delta.
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        return hidden_states
class MultiDoubleStreamBlockLoraProcessorWithLoss(nn.Module):
    """Joint text/image LoRA attention processor that also exposes attention maps.

    Works like a joint attention processor with q/k/v/proj LoRA adapters on the
    ``hidden_states`` stream. In addition, on every call it overwrites two
    attributes on ``attn``:

    * ``attention_probs_query_a_key_noise``
    * ``attention_probs_query_b_key_noise``

    Each holds the averaged, un-softmaxed query/key scores between one third of
    the ``hidden_states`` tokens (queries) and the first third (keys).
    NOTE(review): the three thirds are presumably noise / condition A /
    condition B tokens - confirm against the caller.
    """

    def __init__(self, in_features: int, out_features: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, n_loras=1):
        """Build the q/k/v/proj adapters.

        Args:
            in_features: Input width of every adapter.
            out_features: Output width of every adapter.
            ranks: Per-adapter rank; needs at least ``n_loras`` entries.
            lora_weights: Per-adapter scalar multiplier; stored by reference.
            network_alphas: Per-adapter alpha; needs at least ``n_loras`` entries.
            device: Device passed to every adapter.
            dtype: Dtype passed to every adapter.
            n_loras: Number of adapters per projection.
        """
        super().__init__()
        # `None` sentinels replace the former mutable `[]` defaults so that no
        # list object is shared between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        # Initialize a list to store the LoRA layers
        self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.proj_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.lora_weights = [] if lora_weights is None else lora_weights

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
    ) -> torch.FloatTensor:
        """Run joint attention and record the two attention maps on ``attn``.

        ``use_cond`` is accepted but not used.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` after their respective
            output projections.
        """
        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        # Number of encoder tokens; they occupy the front of the joint sequence.
        # Taken from the input instead of a hard-coded 512 so that the slices
        # below stay aligned for any prompt length.
        encoder_hidden_length = encoder_hidden_states.shape[1]

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Derive the width from the projection itself rather than assuming 3072.
        inner_dim = encoder_hidden_states_query_proj.shape[-1]
        head_dim = inner_dim // attn.heads

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # The `hidden_states` sequence is treated as three equal-length segments.
        length = hidden_states.shape[-2] // 3

        # Each adapter contributes an additive, individually weighted delta.
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # attention: encoder tokens come first on the sequence axis
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Second and third segments act as queries, the first segment as keys.
        query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
        query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
        key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]

        attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
        attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)

        # Overwritten on every call; readers pick them up from the module.
        attn.attention_probs_query_a_key_noise = attention_probs_query_a_key_noise
        attn.attention_probs_query_b_key_noise = attention_probs_query_b_key_noise

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Undo the concatenation: leading tokens belong to the encoder stream.
        encoder_hidden_states, hidden_states = (
            hidden_states[:, :encoder_hidden_length],
            hidden_states[:, encoder_hidden_length:],
        )

        # Linear projection (with LoRA weight applied to each proj layer)
        # NOTE: the proj adapters read the *projected* tensor and each one sees
        # the running sum including earlier adapters' contributions.
        hidden_states = attn.to_out[0](hidden_states)
        for i in range(self.n_loras):
            hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return (hidden_states, encoder_hidden_states)
class MultiDoubleStreamBlockLoraProcessor_visual(nn.Module):
    """Joint text/image LoRA attention processor that logs attention maps.

    Works like a joint attention processor with q/k/v/proj LoRA adapters on the
    ``hidden_states`` stream. On every call it additionally appends a
    ``(global_timestep // 19, attention_map)`` tuple to the lists
    ``attn.attention_probs_query_a_key_noise`` and
    ``attn.attention_probs_query_b_key_noise`` (created on first use), prints
    the current index, and increments the module-level ``global_timestep``.

    NOTE(review): 19 is presumably the number of double-stream blocks invoked
    per denoising step, making ``global_timestep // 19`` a step index -
    confirm against the model configuration.
    """

    def __init__(self, in_features: int, out_features: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, n_loras=1):
        """Build the q/k/v/proj adapters.

        Args:
            in_features: Input width of every adapter.
            out_features: Output width of every adapter.
            ranks: Per-adapter rank; needs at least ``n_loras`` entries.
            lora_weights: Per-adapter scalar multiplier; stored by reference.
            network_alphas: Per-adapter alpha; needs at least ``n_loras`` entries.
            device: Device passed to every adapter.
            dtype: Dtype passed to every adapter.
            n_loras: Number of adapters per projection.
        """
        super().__init__()
        # `None` sentinels replace the former mutable `[]` defaults so that no
        # list object is shared between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        # Initialize a list to store the LoRA layers
        self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.proj_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.lora_weights = [] if lora_weights is None else lora_weights

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond=False,
    ) -> torch.FloatTensor:
        """Run joint attention and append the two attention maps to ``attn``.

        ``use_cond`` is accepted but not used.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` after their respective
            output projections.
        """
        global global_timestep

        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        # Number of encoder tokens; they occupy the front of the joint sequence.
        # Taken from the input instead of a hard-coded 512 so that the slices
        # below stay aligned for any prompt length.
        encoder_hidden_length = encoder_hidden_states.shape[1]

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Derive the width from the projection itself rather than assuming 3072.
        inner_dim = encoder_hidden_states_query_proj.shape[-1]
        head_dim = inner_dim // attn.heads

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # The `hidden_states` sequence is treated as three equal-length segments.
        length = hidden_states.shape[-2] // 3

        # Each adapter contributes an additive, individually weighted delta.
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # attention: encoder tokens come first on the sequence axis
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # Second and third segments act as queries, the first segment as keys.
        query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
        query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
        key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]

        attention_probs_query_a_key_noise = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
        attention_probs_query_b_key_noise = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)

        # The per-module history lists are created lazily on the first call.
        if not hasattr(attn, 'attention_probs_query_a_key_noise'):
            attn.attention_probs_query_a_key_noise = []
        if not hasattr(attn, 'attention_probs_query_b_key_noise'):
            attn.attention_probs_query_b_key_noise = []

        attn.attention_probs_query_a_key_noise.append((global_timestep//19, attention_probs_query_a_key_noise))
        attn.attention_probs_query_b_key_noise.append((global_timestep//19, attention_probs_query_b_key_noise))

        print(f"Global Timestep: {global_timestep//19}")
        global_timestep += 1

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Undo the concatenation: leading tokens belong to the encoder stream.
        encoder_hidden_states, hidden_states = (
            hidden_states[:, :encoder_hidden_length],
            hidden_states[:, encoder_hidden_length:],
        )

        # Linear projection (with LoRA weight applied to each proj layer)
        # NOTE: the proj adapters read the *projected* tensor and each one sees
        # the running sum including earlier adapters' contributions.
        hidden_states = attn.to_out[0](hidden_states)
        for i in range(self.n_loras):
            hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return (hidden_states, encoder_hidden_states)
class MultiSingleStreamBlockLoraProcessor_visual(nn.Module):
    """Single-stream LoRA attention processor that logs attention maps.

    Runs LoRA-augmented self-attention over ``hidden_states`` and, on every
    call, appends a ``(global_timestep2 // 38, attention_map)`` tuple to the
    lists ``attn.attention_probs_query_a_key_noise2`` and
    ``attn.attention_probs_query_b_key_noise2`` (created on first use), prints
    the current index, and increments the module-level ``global_timestep2``.

    The sequence is sliced assuming 512 leading tokens followed by three
    equal-length segments.
    NOTE(review): 512 is presumably the text-token count and 38 the number of
    single-stream blocks invoked per denoising step - confirm against the
    pipeline and model configuration.
    """

    def __init__(self, in_features: int, out_features: int, ranks=None, lora_weights=None, network_alphas=None, device=None, dtype=None, n_loras=1):
        """Build the q/k/v adapters.

        Args:
            in_features: Input width of every adapter.
            out_features: Output width of every adapter.
            ranks: Per-adapter rank; needs at least ``n_loras`` entries.
            lora_weights: Per-adapter scalar multiplier; stored by reference.
            network_alphas: Per-adapter alpha; needs at least ``n_loras`` entries.
            device: Device passed to every adapter.
            dtype: Dtype passed to every adapter.
            n_loras: Number of adapters per projection.
        """
        super().__init__()
        # `None` sentinels replace the former mutable `[]` defaults so that no
        # list object is shared between instances.
        ranks = [] if ranks is None else ranks
        network_alphas = [] if network_alphas is None else network_alphas
        # Initialize a list to store the LoRA layers
        self.n_loras = n_loras
        self.q_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.k_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.v_loras = nn.ModuleList([
            LoRALinearLayer(in_features, out_features, ranks[i], network_alphas[i], device=device, dtype=dtype)
            for i in range(n_loras)
        ])
        self.lora_weights = [] if lora_weights is None else lora_weights

    def __call__(self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        use_cond = False,
    ) -> torch.FloatTensor:
        """Run self-attention and append the two attention maps to ``attn``.

        ``encoder_hidden_states`` is only consulted for the batch size when it
        is provided; ``use_cond`` is accepted but not used.

        Returns:
            Tensor of shape ``(batch, seq, attn.heads * head_dim)``.
        """
        global global_timestep2

        reference = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size = reference.shape[0]

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # Leading tokens to skip, then three equal-length segments.
        encoder_hidden_length = 512
        length = (hidden_states.shape[-2] - encoder_hidden_length) // 3

        # Each adapter contributes an additive, individually weighted delta.
        for i in range(self.n_loras):
            query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
            key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
            value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # The per-module history lists are created lazily on the first call.
        if not hasattr(attn, 'attention_probs_query_a_key_noise2'):
            attn.attention_probs_query_a_key_noise2 = []
        if not hasattr(attn, 'attention_probs_query_b_key_noise2'):
            attn.attention_probs_query_b_key_noise2 = []

        # Second and third segments act as queries, the first segment as keys.
        query_cond_a = query[:, :, encoder_hidden_length+length : encoder_hidden_length+2*length, :]
        query_cond_b = query[:, :, encoder_hidden_length+2*length : encoder_hidden_length+3*length, :]
        key_noise = key[:, :, encoder_hidden_length:encoder_hidden_length+length, :]

        attention_probs_query_a_key_noise2 = scaled_dot_product_average_attention_map(query_cond_a, key_noise, attn_mask=attention_mask, is_causal=False)
        attention_probs_query_b_key_noise2 = scaled_dot_product_average_attention_map(query_cond_b, key_noise, attn_mask=attention_mask, is_causal=False)

        # Tag each record with this processor's own counter (`global_timestep2`),
        # the same one that is printed and incremented below; tagging with the
        # double-stream counter `global_timestep` would yield a mismatched index.
        attn.attention_probs_query_a_key_noise2.append((global_timestep2//38, attention_probs_query_a_key_noise2))
        attn.attention_probs_query_b_key_noise2.append((global_timestep2//38, attention_probs_query_b_key_noise2))

        print(f"Global Timestep2: {global_timestep2//38}")
        global_timestep2 += 1

        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        return hidden_states