from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor2_0,
    SlicedAttnProcessor,
    XFormersAttnProcessor,
)

try:
    import xformers.ops
except ImportError:
    xformers = None


# Hypernetwork modules registered here are applied to the attention context by
# apply_hypernetworks(); an empty list means vanilla attention.
loaded_networks = []

def apply_single_hypernetwork(hypernetwork, hidden_states, encoder_hidden_states):
    context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
    return context_k, context_v

def apply_hypernetworks(context_k, context_v, layer=None):
    if len(loaded_networks) == 0:
        # No hypernetworks loaded: use the unmodified context for both key and value,
        # i.e. vanilla attention behavior.
        return context_v, context_v

    org_dtype = context_k.dtype
    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    # Cast back to the caller's dtype in case a hypernetwork runs in a different one.
    context_k = context_k.to(dtype=org_dtype)
    context_v = context_v.to(dtype=org_dtype)

    return context_k, context_v

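
# The functions above only assume that a registered hypernetwork is a module exposing
# forward(context_k, context_v) and returning the two transformed context tensors.
# A minimal sketch of a compatible module might look like this (illustrative only;
# the class and attribute names below are not defined in this project):
#
#   class ExampleHypernetwork(torch.nn.Module):
#       def __init__(self, context_dim):
#           super().__init__()
#           self.proj_k = torch.nn.Linear(context_dim, context_dim)
#           self.proj_v = torch.nn.Linear(context_dim, context_dim)
#
#       def forward(self, context_k, context_v):
#           return self.proj_k(context_k), self.proj_v(context_v)
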
def xformers_forward(
    self: XFormersAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )

    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    query = attn.head_to_batch_dim(query).contiguous()
    key = attn.head_to_batch_dim(key).contiguous()
    value = attn.head_to_batch_dim(value).contiguous()

    hidden_states = xformers.ops.memory_efficient_attention(
        query,
        key,
        value,
        attn_bias=attention_mask,
        op=self.attention_op,
        scale=attn.scale,
    )
    hidden_states = hidden_states.to(query.dtype)
    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear projection
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states

def sliced_attn_forward(
    self: SlicedAttnProcessor,
    attn: Attention,
    hidden_states: torch.Tensor,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    attention_mask = attn.prepare_attention_mask(
        attention_mask, sequence_length, batch_size
    )

    query = attn.to_q(hidden_states)
    dim = query.shape[-1]
    query = attn.head_to_batch_dim(query)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)
    key = attn.head_to_batch_dim(key)
    value = attn.head_to_batch_dim(value)

    batch_size_attention, query_tokens, _ = query.shape
    hidden_states = torch.zeros(
        (batch_size_attention, query_tokens, dim // attn.heads),
        device=query.device,
        dtype=query.dtype,
    )

    # Compute attention one slice of the batch*heads dimension at a time to bound peak
    # memory usage; round up so a trailing partial slice is not dropped.
    for i in range((batch_size_attention - 1) // self.slice_size + 1):
        start_idx = i * self.slice_size
        end_idx = (i + 1) * self.slice_size

        query_slice = query[start_idx:end_idx]
        key_slice = key[start_idx:end_idx]
        attn_mask_slice = (
            attention_mask[start_idx:end_idx] if attention_mask is not None else None
        )

        attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

        hidden_states[start_idx:end_idx] = attn_slice

    hidden_states = attn.batch_to_head_dim(hidden_states)

    # linear projection
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)

    return hidden_states

def v2_0_forward(
    self: AttnProcessor2_0,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
):
    batch_size, sequence_length, _ = (
        hidden_states.shape
        if encoder_hidden_states is None
        else encoder_hidden_states.shape
    )
    inner_dim = hidden_states.shape[-1]

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        # scaled_dot_product_attention expects the attention mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.view(
            batch_size, attn.heads, -1, attention_mask.shape[-1]
        )

    query = attn.to_q(hidden_states)

    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    elif attn.norm_cross:
        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

    context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

    key = attn.to_k(context_k)
    value = attn.to_v(context_v)

    head_dim = inner_dim // attn.heads
    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # output shape: (batch, heads, seq_len, head_dim)
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(
        batch_size, -1, attn.heads * head_dim
    )
    hidden_states = hidden_states.to(query.dtype)

    # linear projection
    hidden_states = attn.to_out[0](hidden_states)
    # dropout
    hidden_states = attn.to_out[1](hidden_states)
    return hidden_states

def replace_attentions_for_hypernetwork():
    # Monkeypatch diffusers' attention processors so every attention call goes through
    # the hypernetwork-aware forwards defined above.
    import diffusers.models.attention_processor

    diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
        xformers_forward
    )
    diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
        sliced_attn_forward
    )
    diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
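
# Example usage (a minimal sketch; `my_hypernetwork` is a placeholder for a compatible
# hypernetwork module and is not defined in this file):
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   loaded_networks.append(my_hypernetwork)   # register the hypernetwork
#   replace_attentions_for_hypernetwork()     # patch diffusers' attention processors
#   image = pipe("a photo of an astronaut riding a horse").images[0]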
|