Spaces:
Sleeping
Sleeping
from typing import Any, Dict, Optional | |
import torch | |
from torch import nn | |
from diffusers.models.attention import Attention | |
from diffusers.utils.import_utils import is_xformers_available | |
from einops import rearrange, repeat | |
import math | |
import torch.nn.functional as F | |
if is_xformers_available(): | |
import xformers | |
import xformers.ops | |
else: | |
xformers = None | |
class RowwiseMVAttention(Attention):
    """Attention module whose processors perform row-wise multi-view attention.

    ``set_use_memory_efficient_attention_xformers`` toggles between the
    xformers-backed ``XFormersMVAttnProcessor`` and the plain-PyTorch
    ``RowwiseMVProcessor`` defined in this module.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        """Install the xformers or the plain row-wise multi-view processor.

        Args:
            use_memory_efficient_attention_xformers: True installs
                ``XFormersMVAttnProcessor``; False installs
                ``RowwiseMVProcessor`` so that disabling xformers actually
                takes effect.
            *args, **kwargs: accepted for signature compatibility with the
                diffusers hook (e.g. an ``attention_op``) and ignored.
        """
        if use_memory_efficient_attention_xformers:
            processor = XFormersMVAttnProcessor()
        else:
            processor = RowwiseMVProcessor()
        self.set_processor(processor)
class IPCDAttention(Attention):
    """Attention module paired with ``XFormersIPCDAttnProcessor``.

    Requesting memory-efficient attention always installs the xformers-backed
    processor; the boolean flag and any extra arguments are not consulted.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # NOTE(review): the flag is ignored -- no non-xformers IPCD processor
        # exists in this module to fall back to.
        self.set_processor(XFormersIPCDAttnProcessor())
class XFormersMVAttnProcessor:
    r"""
    xformers-backed row-wise multi-view attention processor.

    Tokens are regrouped from ``(b v) (h w) c`` to ``(b h) (v w) c``, so each
    attention sequence holds one image row taken across all ``num_views``
    views. With ``cd_attention_mid`` the two halves of the batch are first
    placed side by side along the width axis, so they also attend to each other.
    The token grid is assumed square (``height = int(sqrt(sequence_length))``).
    ``multiview_attention`` is accepted but not used.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        multiview_attention=True,
        cd_attention_mid=False
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape
        height = int(math.sqrt(sequence_length))

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # xformers does not broadcast the singleton query dimension, so expand
            # [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens].
            # NOTE(review): the expanded mask is sized for the un-regrouped tokens;
            # callers are expected to pass None here.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            print('Warning: using group norm, pay attention to use it in row-wise attention')
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def to_rows(tensor):
            """(b v) (h w) c -> (b h) (v w) c, optionally joining batch halves along w."""
            if not cd_attention_mid:
                return rearrange(tensor, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            grid = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            first_half, second_half = torch.chunk(grid, chunks=2, dim=0)
            grid = torch.cat([first_half, second_half], dim=3)  # b v h 2w c
            return rearrange(grid, "b v h w c -> (b h) (v w) c", v=num_views, h=height)

        def from_rows(tensor):
            """Inverse of ``to_rows``: (b h) (v w) c -> (b v) (h w) c."""
            if not cd_attention_mid:
                return rearrange(tensor, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
            grid = rearrange(tensor, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            left, right = torch.chunk(grid, chunks=2, dim=3)
            grid = torch.cat([left, right], dim=0)  # 2b v h w c
            return rearrange(grid, "b v h w c -> (b v) (h w) c", v=num_views, h=height)

        query, key, value = (attn.head_to_batch_dim(to_rows(t)) for t in (query, key, value))

        out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        out = attn.batch_to_head_dim(out)
        out = attn.to_out[0](out)  # linear projection
        out = attn.to_out[1](out)  # dropout
        out = from_rows(out)

        if input_ndim == 4:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + residual

        return out / attn.rescale_output_factor
class XFormersIPCDAttnProcessor:
    r"""
    xformers-backed cross-domain attention processor with last-view blending.

    The incoming batch is split into two halves along dim 0 (``num_tasks`` must
    be 2; the code names them "normal" and "color"). The halves' token sequences
    are concatenated, so every token attends over both halves. After the output
    projection the halves are separated again. In each half, the last of the
    ``num_views + 1`` views (named ``face`` in the code) is detached, resized and
    averaged 50/50 into the top-center patch of view 0.
    """

    @staticmethod
    def _blend_last_view_into_first(states, num_views, height, height_st, height_end, weight=0.5):
        """Return ``states`` with its resized last view blended into view 0's top-center patch.

        ``states`` is laid out as ``(b * (num_views + 1), height * w, c)`` and the
        result has the same layout. The patch covers rows ``[0, height_st)`` and
        columns ``[height_st, height_end)`` of view 0.
        """
        patch_h = height_st
        patch_w = height_end - height_st
        if patch_h <= 0 or patch_w <= 0:
            # Degenerate (tiny) grid: the patch is empty, nothing to blend.
            return states
        grid = rearrange(states, "(b v) (h w) c -> b v h w c", v=num_views + 1, h=height)
        face = rearrange(grid[:, -1, :, :, :], 'b h w c -> b c h w').detach()
        # Resize to the patch's actual size: (height_st, height_st) only matches
        # the patch when height is a multiple of 3; otherwise the assignment
        # below fails with a shape mismatch.
        face = F.interpolate(face, size=(patch_h, patch_w), mode='bilinear')
        face = rearrange(face, 'b c h w -> b h w c')
        grid = grid.clone()  # write into a copy, not into the attention output
        patch = grid[:, 0, :height_st, height_st:height_end, :]
        grid[:, 0, :height_st, height_st:height_end, :] = (1.0 - weight) * patch + weight * face
        return rearrange(grid, "b v h w c -> (b v) (h w) c")

    def process(self,
                attn: Attention,
                hidden_states,
                encoder_hidden_states=None,
                attention_mask=None,
                temb=None,
                num_tasks=2,
                num_views=6):
        """Run joint two-half attention, then blend the last view into view 0.

        Args:
            attn: the owning attention module (supplies projections and norms).
            hidden_states: query-side tokens; both halves of the batch stacked
                along dim 0.
            encoder_hidden_states: key/value source; defaults to ``hidden_states``.
            attention_mask: optional mask. NOTE(review): it is prepared for the
                un-joined batch/token layout; callers are expected to pass None.
            temb: forwarded to ``attn.spatial_norm`` when present.
            num_tasks: must be 2.
            num_views: views per sample excluding the trailing blended view, so
                each half holds ``num_views + 1`` views per sample.

        Returns:
            Tensor with the same layout as ``hidden_states``.

        Raises:
            ValueError: if ``num_tasks != 2``.
        """
        if num_tasks != 2:
            raise ValueError(f"XFormersIPCDAttnProcessor only supports num_tasks == 2, got {num_tasks}")

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Square token grid assumed. The blended patch spans rows [0, height_st)
        # and columns [height_st, height_end), i.e. the top-center region.
        height = int(math.sqrt(sequence_length))
        height_st = height // 3
        height_end = height - height_st

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # xformers does not broadcast the singleton query dimension, so expand
            # [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens].
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def join_tasks(tensor):
            # (2*bv, hw, c) -> (bv, 2*hw, c): both halves share one token sequence.
            first, second = torch.chunk(tensor, chunks=2, dim=0)
            return torch.cat([first, second], dim=1)

        query = attn.head_to_batch_dim(join_tasks(query)).contiguous()
        key = attn.head_to_batch_dim(join_tasks(key)).contiguous()
        value = attn.head_to_batch_dim(join_tasks(value)).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # linear proj
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        # (bv, 2*hw, c) -> "normal" and "color" halves of shape (bv, hw, c).
        hidden_states_normal, hidden_states_color = torch.chunk(hidden_states, chunks=2, dim=1)
        hidden_states_normal = self._blend_last_view_into_first(
            hidden_states_normal, num_views, height, height_st, height_end)
        hidden_states_color = self._blend_last_view_into_first(
            hidden_states_color, num_views, height, height_st, height_end)
        hidden_states = torch.cat([hidden_states_normal, hidden_states_color], dim=0)  # 2bv hw c

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2,
    ):
        """Processor entry point; delegates to ``process`` with its default ``num_views``."""
        return self.process(attn, hidden_states, encoder_hidden_states, attention_mask, temb, num_tasks)
class IPCrossAttn(Attention):
    r"""
    Cross-attention module intended for IP-Adapter-style image conditioning.

    Args:
        query_dim (`int`): channels of the query tokens; forwarded to ``Attention``.
        cross_attention_dim (`int`): channels of ``encoder_hidden_states``;
            forwarded to ``Attention``.
        heads (`int`): number of attention heads; forwarded to ``Attention``.
        dim_head (`int`): channels per head; forwarded to ``Attention``.
        dropout (`float`): forwarded to ``Attention``.
        bias (`bool`): forwarded to ``Attention``.
        upcast_attention (`bool`): forwarded to ``Attention``.
        ip_scale (`float`, defaults to 1.0): weight of the image prompt, stored
            as ``self.ip_scale``.
    """

    def __init__(self,
                 query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, ip_scale=1.0):
        # Pass by keyword: newer diffusers releases insert extra parameters
        # (e.g. ``kv_heads`` after ``heads``) into ``Attention.__init__``, which
        # would silently shift positional arguments.
        super().__init__(
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            bias=bias,
            upcast_attention=upcast_attention,
        )
        self.ip_scale = ip_scale

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # NOTE(review): the flag is ignored; the xformers processor is always installed.
        processor = XFormersIPCrossAttnProcessor()
        self.set_processor(processor)
class XFormersIPCrossAttnProcessor:
    r"""
    xformers-backed (cross-)attention processor used with ``IPCrossAttn``.

    Performs standard attention via ``xformers.ops.memory_efficient_attention``.
    When ``encoder_hidden_states`` is None it falls back to self-attention, as the
    other processors in this module do. ``num_views`` is accepted for call-site
    compatibility but is not used.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1
    ):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # xformers does not broadcast the singleton query dimension, so expand
            # [batch*heads, 1, key_tokens] -> [batch*heads, query_tokens, key_tokens].
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Previously a None context crashed in ``to_k`` and ``norm_cross`` was
        # silently skipped; handle both like the sibling processors.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states)  # linear proj
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        # TODO: region control (masking the image-prompt contribution per region).
        return hidden_states
class RowwiseMVProcessor:
    r"""
    Plain-PyTorch row-wise multi-view attention processor.

    Tokens are regrouped from ``(b v) (h w) c`` to ``(b h) (v w) c``, so each
    attention sequence holds one image row across all ``num_views`` views. With
    ``cd_attention_mid`` the two halves of the batch are first placed side by
    side along the width axis. Attention is computed with
    ``attn.get_attention_scores`` followed by ``torch.bmm``. The token grid is
    assumed square (``height = int(sqrt(sequence_length))``).
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_views=1,
        cd_attention_mid=False
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape
        height = int(math.sqrt(sequence_length))
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def gather_rows(tensor):
            """(b v) (h w) c -> (b h) (v w) c, optionally joining batch halves along w."""
            if not cd_attention_mid:
                return rearrange(tensor, "(b v) (h w) c -> (b h) (v w) c", v=num_views, h=height)
            grid = rearrange(tensor, "(b v) (h w) c -> b v h w c", v=num_views, h=height)
            top, bottom = torch.chunk(grid, chunks=2, dim=0)
            return rearrange(torch.cat([top, bottom], dim=3), "b v h w c -> (b h) (v w) c", v=num_views, h=height)

        def scatter_rows(tensor):
            """Inverse of ``gather_rows``: (b h) (v w) c -> (b v) (h w) c."""
            if not cd_attention_mid:
                return rearrange(tensor, "(b h) (v w) c -> (b v) (h w) c", v=num_views, h=height)
            grid = rearrange(tensor, "(b h) (v w) c -> b v h w c", v=num_views, h=height)
            left, right = torch.chunk(grid, chunks=2, dim=3)
            return rearrange(torch.cat([left, right], dim=0), "b v h w c -> (b v) (h w) c", v=num_views, h=height)

        query, key, value = (
            attn.head_to_batch_dim(gather_rows(t)).contiguous() for t in (query, key, value)
        )

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))
        out = attn.to_out[1](attn.to_out[0](out))  # linear projection, then dropout
        out = scatter_rows(out)

        if input_ndim == 4:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + residual

        return out / attn.rescale_output_factor
class CDAttention(Attention):
    """Attention module paired with ``XFormersCDAttnProcessor``.

    Requesting memory-efficient attention always installs the xformers-backed
    processor; the boolean flag and any extra arguments are not consulted.
    """

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
    ):
        # NOTE(review): the flag is ignored -- no non-xformers CD processor
        # exists in this module to fall back to.
        self.set_processor(XFormersCDAttnProcessor())
class XFormersCDAttnProcessor:
    r"""
    xformers-backed cross-domain attention processor.

    The batch is split into two halves along dim 0 (``num_tasks`` must be 2) and
    the halves' token sequences are concatenated, so each token attends over the
    tokens of both halves. After the output projection the concatenation is
    undone, restoring the input's batch/token layout.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        num_tasks=2
    ):
        if num_tasks != 2:
            raise ValueError(f"XFormersCDAttnProcessor only supports num_tasks == 2, got {num_tasks}")

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # NOTE(review): the mask is prepared for the un-joined batch/token layout;
        # callers are expected to pass None.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def join_tasks(tensor):
            # (2*bv, hw, c) -> (bv, 2*hw, c): both halves share one token sequence.
            first, second = torch.chunk(tensor, chunks=2, dim=0)
            return torch.cat([first, second], dim=1)

        query = attn.head_to_batch_dim(join_tasks(query)).contiguous()
        key = attn.head_to_batch_dim(join_tasks(key)).contiguous()
        value = attn.head_to_batch_dim(join_tasks(value)).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # linear proj
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        # Undo join_tasks: split the token axis in two and restack along the
        # batch, (bv, 2*hw, c) -> (2*bv, hw, c). Indexing hidden_states[:, 0] /
        # [:, 1] would select single tokens and drop the sequence dimension.
        hidden_states = torch.cat(torch.chunk(hidden_states, chunks=2, dim=1), dim=0)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states