moistdio's picture
Upload folder using huggingface_hub
6831a54 verified
import math
import torch
import einops
from backend.args import args
from backend import memory_management
from backend.misc.sub_quadratic_attention import efficient_dot_product_attention
# Set to True for xformers releases matching "0.0.2x" other than "0.0.20".
# attention_xformers reads this flag and falls back to attention_pytorch when
# b * heads > 65535 (presumably a known limitation of those releases — confirm).
BROKEN_XFORMERS = False
if memory_management.xformers_enabled():
    import xformers
    import xformers.ops

    try:
        x_vers = xformers.__version__
        BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
    except Exception:
        # Best-effort version sniffing: a missing or non-string __version__
        # simply leaves BROKEN_XFORMERS False. Catching Exception (rather than
        # a bare except) keeps KeyboardInterrupt/SystemExit propagating.
        pass

# Dtype that get_attn_precision returns in place of the caller's request when
# not None (and attention upcasting is not disabled).
FORCE_UPCAST_ATTENTION_DTYPE = memory_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision=torch.float32):
    """Resolve the dtype attention scores should be computed in.

    Returns None when ``args.disable_attention_upcast`` is set. Otherwise the
    module-level ``FORCE_UPCAST_ATTENTION_DTYPE`` takes priority whenever it
    is not None, and ``attn_precision`` is returned unchanged as the fallback.
    """
    if not args.disable_attention_upcast:
        forced_dtype = FORCE_UPCAST_ATTENTION_DTYPE
        return attn_precision if forced_dtype is None else forced_dtype
    return None
def exists(val):
    """Return True when ``val`` is anything other than None (falsy values count)."""
    is_present = val is not None
    return is_present
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    """Reference multi-head attention computed with explicit einsums.

    Args:
        q, k, v: Packed ``(b, tokens, heads * dim_head)`` tensors, or — when
            ``skip_reshape`` is True — tensors already split per head, assumed
            to be ``(b, heads, tokens, dim_head)``.
        heads: Number of attention heads.
        mask: Optional mask. A bool mask is flattened per batch element,
            repeated across heads, and positions where it is False are filled
            with the most negative finite value of the score dtype. Any other
            dtype is treated as an additive bias, reshaped/expanded to
            ``(b * heads, q_rows, k_tokens)`` and added to the scores.
        attn_precision: Requested score precision, resolved through
            ``get_attn_precision``; float32 computes the scores from float
            copies of q and k.
        skip_reshape: See ``q, k, v``.

    Returns:
        Tensor of shape ``(b, tokens, heads * dim_head)``.
    """
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, packed_dim = q.shape
        dim_head = packed_dim // heads

    scale = dim_head ** -0.5

    def to_batched_heads(t):
        # Fold the head axis into the batch axis: (b * heads, tokens, dim_head).
        if skip_reshape:
            return t.reshape(b * heads, -1, dim_head)
        return (
            t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous()
        )

    q, k, v = (to_batched_heads(t) for t in (q, k, v))

    # Upcast before the score matmul so large logits cannot overflow.
    if attn_precision == torch.float32:
        q, k = q.float(), k.float()
    sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    del q, k

    if mask is not None:
        if mask.dtype == torch.bool:
            keep = einops.rearrange(mask, 'b ... -> b (...)')
            keep = einops.repeat(keep, 'b j -> (b h) () j', h=heads)
            sim.masked_fill_(~keep, -torch.finfo(sim.dtype).max)
        else:
            mask_batch = 1 if mask.ndim == 2 else mask.shape[0]
            q_rows, k_cols = mask.shape[-2], mask.shape[-1]
            bias = (
                mask.reshape(mask_batch, -1, q_rows, k_cols)
                .expand(b, heads, -1, -1)
                .reshape(-1, q_rows, k_cols)
            )
            sim.add_(bias)

    probs = sim.softmax(dim=-1)
    del sim

    out = torch.einsum('b i j, b j d -> b i d', probs.to(v.dtype), v)
    # (b * heads, tokens, dim_head) -> (b, tokens, heads * dim_head)
    return (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
    """Chunked multi-head attention via ``efficient_dot_product_attention``.

    The query chunk size is chosen from the free memory reported by
    ``memory_management.get_free_memory``.

    Args:
        query, key, value: Packed ``(b, tokens, heads * dim_head)`` tensors, or
            — when ``skip_reshape`` is True — tensors already split per head,
            assumed to be ``(b, heads, tokens, dim_head)``.
        heads: Number of attention heads.
        mask: Optional additive mask, reshaped/expanded to
            ``(b * heads, q_rows, k_tokens)`` before being forwarded.
        attn_precision: Requested score precision, resolved through
            ``get_attn_precision``. float32 enables ``upcast_attention`` when
            the inputs are not already float32.
        skip_reshape: See ``query, key, value``.

    Returns:
        Tensor of shape ``(b, tokens, heads * dim_head)`` in the input dtype.
    """
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = query.shape
    else:
        b, _, dim_head = query.shape
        dim_head //= heads

    # Fold heads into the batch axis. Key is laid out transposed as
    # (b * heads, dim_head, k_tokens), which is how it is handed to
    # efficient_dot_product_attention below. No scale is passed; presumably
    # the helper derives it from dim_head itself — confirm there.
    if skip_reshape:
        query = query.reshape(b * heads, -1, dim_head)
        value = value.reshape(b * heads, -1, dim_head)
        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
    else:
        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    score_dtype = torch.float32 if upcast_attention else query.dtype
    bytes_per_token = torch.finfo(score_dtype).bits // 8

    batch_x_heads = query.shape[0]
    k_tokens = key.shape[2]

    mem_free_total, _ = memory_management.get_free_memory(query.device, True)

    # Take the largest query chunk whose estimated footprint leaves room for all
    # key tokens in one kv chunk; otherwise fall back to 512 with the kv chunk
    # size left to the helper (None).
    kv_chunk_size = None
    query_chunk_size = 512
    for x in (4096, 2048, 1024, 512, 256):
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if mask is not None:
        bs = 1 if len(mask.shape) == 2 else mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=None,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    hidden_states = hidden_states.to(dtype)

    # (b * heads, tokens, dim_head) -> (b, tokens, heads * dim_head)
    return hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    """Multi-head attention computed in query slices to bound peak memory.

    The number of slices (``steps``, a power of two) is estimated from free
    memory. On OOM before the first slice finishes, the cache is emptied once
    and retried, then ``steps`` is doubled up to 64.

    Args:
        q, k, v: Packed ``(b, tokens, heads * dim_head)`` tensors, or — when
            ``skip_reshape`` is True — tensors already split per head, assumed
            to be ``(b, heads, tokens, dim_head)``.
        heads: Number of attention heads.
        mask: Optional additive mask, reshaped/expanded to
            ``(b * heads, q_rows, k_tokens)``. A mask with ``q_rows == 1`` is
            broadcast over every query slice.
        attn_precision: Requested score precision, resolved through
            ``get_attn_precision``. float32 computes scores in float32 with
            autocast disabled.
        skip_reshape: See ``q, k, v``.

    Returns:
        Tensor of shape ``(b, tokens, heads * dim_head)``.

    Raises:
        RuntimeError: If the up-front estimate needs more than 64 slices.
        memory_management.OOM_EXCEPTION: Re-raised when retries are exhausted
            or when OOM happens after the first slice completed.
    """
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = memory_management.get_free_memory(q.device)

    upcast = attn_precision == torch.float32
    element_size = 4 if upcast else q.element_size()

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    if mask is not None:
        bs = 1 if len(mask.shape) == 2 else mask.shape[0]
        # Always 3-D from here on: (b * heads, q_rows, k_tokens).
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    q_tokens = q.shape[1]
    first_op_done = False
    cleared_cache = False
    while True:
        try:
            # Ceil division so every step count actually slices the queries.
            # Previously a non-divisible length fell back to a single full
            # slice, which made the OOM retries useless.
            slice_size = max(1, (q_tokens + steps - 1) // steps)
            for i in range(0, q_tokens, slice_size):
                end = i + slice_size
                if upcast:
                    # NOTE(review): device_type is hard-coded to 'cuda'.
                    with torch.autocast(enabled=False, device_type='cuda'):
                        s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    # A single-row mask broadcasts over all queries; slicing it
                    # at i > 0 would yield an empty tensor.
                    s1 += mask if mask.shape[-2] == 1 else mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = torch.einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except memory_management.OOM_EXCEPTION:
            if first_op_done:
                raise
            memory_management.soft_empty_cache(True)
            if not cleared_cache:
                cleared_cache = True
                print("out of memory error, emptying cache and trying again")
                continue
            steps *= 2
            if steps > 64:
                raise
            print("out of memory error, increasing steps and trying again {}".format(steps))

    del q, k, v

    # (b * heads, tokens, dim_head) -> (b, tokens, heads * dim_head)
    return (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    """Multi-head attention via ``xformers.ops.memory_efficient_attention``.

    Falls back to ``attention_pytorch`` when ``BROKEN_XFORMERS`` is set and
    ``b * heads > 65535``.

    Args:
        q, k, v: Packed ``(b, tokens, heads * dim_head)`` tensors, or — when
            ``skip_reshape`` is True — tensors already split per head, assumed
            to be ``(b, heads, tokens, dim_head)``.
        heads: Number of attention heads.
        mask: Optional attention bias whose last dimension is the key length.
            It is copied into a buffer whose rows are padded to a multiple of
            8 elements, and a view of the unpadded part is passed to xformers
            (presumably to satisfy an alignment requirement — confirm against
            the xformers docs).
        attn_precision: Accepted for signature compatibility; not used.
        skip_reshape: See ``q, k, v``.

    Returns:
        Tensor of shape ``(b, tokens, heads * dim_head)``.
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    if BROKEN_XFORMERS and b * heads > 65535:
        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.reshape(b, -1, heads, dim_head),
            (q, k, v),
        )

    if mask is not None:
        # Pad the buffer rows based on the KEY length (the mask's last dim).
        # Previously the query length was used, which allocated (Nq, Nq + pad)
        # per batch in cross-attention and failed outright when Nk > Nq + pad.
        # NOTE(review): the buffer is 3-D (q.shape[0], q.shape[1], ...); a 4-D
        # (b, heads, Nq, Nk) mask will not broadcast into it — confirm callers.
        k_len = mask.shape[-1]
        pad = 8 - k_len % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], k_len + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :k_len] = mask
        mask = mask_out[:, :, :k_len]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    if skip_reshape:
        # (b * heads, tokens, dim_head) -> (b, tokens, heads * dim_head)
        return (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    # (b, tokens, heads, dim_head) -> (b, tokens, heads * dim_head)
    return out.reshape(b, -1, heads * dim_head)
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    """Multi-head attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q, k, v: Packed ``(b, tokens, heads * dim_head)`` tensors, or — when
            ``skip_reshape`` is True — tensors already laid out as
            ``(b, heads, tokens, dim_head)``, which SDPA consumes directly.
        heads: Number of attention heads.
        mask: Optional ``attn_mask`` forwarded unchanged to SDPA.
        attn_precision: Accepted for signature compatibility; not used.
        skip_reshape: See ``q, k, v``.

    Returns:
        Tensor of shape ``(b, tokens, heads * dim_head)``.
    """
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, packed_dim = q.shape
        dim_head = packed_dim // heads
        # (b, tokens, heads * dim_head) -> (b, heads, tokens, dim_head)
        q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))

    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    # (b, heads, tokens, dim_head) -> (b, tokens, heads * dim_head)
    return attended.transpose(1, 2).reshape(b, -1, heads * dim_head)
def slice_attention_single_head_spatial(q, k, v):
    """Single-head attention computed in query slices to bound peak memory.

    Args:
        q: Queries indexed as ``q[:, i:end]`` along dim 1 and multiplied as
            ``bmm(q_slice, k)`` — expected ``(b, q_tokens, c)``.
        k: Keys, expected ``(b, c, k_tokens)``; also defines the output shape.
        v: Values, expected ``(b, c, k_tokens)``.

    Returns:
        Tensor shaped like ``k`` whose column ``i`` is the attention output for
        query ``i`` (assumes ``q_tokens == k_tokens``, as in spatial
        self-attention — confirm for other callers).

    Raises:
        memory_management.OOM_EXCEPTION: Re-raised once doubling ``steps``
            would exceed 128.
    """
    r1 = torch.zeros_like(k, device=q.device)
    scale = (int(q.shape[-1]) ** (-0.5))

    mem_free_total = memory_management.get_free_memory(q.device)

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    q_tokens = q.shape[1]
    while True:
        try:
            # Ceil division: each query row's softmax is independent, so any
            # partition is exact. Previously a non-divisible length fell back
            # to one full slice, defeating the OOM retries. max(1, ...) keeps
            # the range step non-zero for empty inputs.
            slice_size = max(1, (q_tokens + steps - 1) // steps)
            for i in range(0, q_tokens, slice_size):
                end = i + slice_size
                s1 = torch.bmm(q[:, i:end], k) * scale

                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
                del s1

                r1[:, :, i:end] = torch.bmm(v, s2)
                del s2
            break
        except memory_management.OOM_EXCEPTION:
            memory_management.soft_empty_cache(True)
            steps *= 2
            if steps > 128:
                raise
            print("out of memory error, increasing steps and trying again {}".format(steps))

    return r1
def normal_attention_single_head_spatial(q, k, v):
    """Single-head spatial attention over ``(b, c, h, w)`` maps using sliced attention.

    Flattens the spatial grid into ``h * w`` tokens and delegates to
    ``slice_attention_single_head_spatial``. The result is reshaped back to
    ``(b, c, h, w)``.
    """
    b, c, h, w = q.shape
    tokens = h * w

    queries = q.reshape(b, c, tokens).permute(0, 2, 1)  # (b, hw, c)
    keys = k.reshape(b, c, tokens)                      # (b, c, hw)
    values = v.reshape(b, c, tokens)                    # (b, c, hw)

    attended = slice_attention_single_head_spatial(queries, keys, values)
    return attended.reshape(b, c, h, w)
def xformers_attention_single_head_spatial(q, k, v):
    """Single-head spatial attention over ``(B, C, H, W)`` maps via xformers.

    Channels act as the feature dimension and the ``H * W`` positions as
    tokens. If xformers raises NotImplementedError, the computation is redone
    with ``slice_attention_single_head_spatial``.
    """
    B, C, H, W = q.shape

    # (B, C, H, W) -> (B, H*W, C)
    q, k, v = (t.view(B, C, -1).transpose(1, 2).contiguous() for t in (q, k, v))

    try:
        attended = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        return attended.transpose(1, 2).reshape(B, C, H, W)
    except NotImplementedError:
        return slice_attention_single_head_spatial(
            q.view(B, -1, C),
            k.view(B, -1, C).transpose(1, 2),
            v.view(B, -1, C).transpose(1, 2),
        ).reshape(B, C, H, W)
def pytorch_attention_single_head_spatial(q, k, v):
    """Single-head spatial attention over ``(B, C, H, W)`` maps via PyTorch SDPA.

    Channels act as the feature dimension and the ``H * W`` positions as
    tokens. On ``memory_management.OOM_EXCEPTION`` it prints a notice and
    recomputes with ``slice_attention_single_head_spatial``.
    """
    B, C, H, W = q.shape

    def as_tokens(t):
        # (B, C, H, W) -> (B, 1, H*W, C): a single head whose features are channels.
        return t.view(B, 1, C, -1).transpose(2, 3).contiguous()

    q, k, v = as_tokens(q), as_tokens(k), as_tokens(v)

    try:
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        return attended.transpose(2, 3).reshape(B, C, H, W)
    except memory_management.OOM_EXCEPTION:
        print("scaled_dot_product_attention OOMed: switched to slice attention")
        return slice_attention_single_head_spatial(
            q.view(B, -1, C),
            k.view(B, -1, C).transpose(1, 2),
            v.view(B, -1, C).transpose(1, 2),
        ).reshape(B, C, H, W)
# Choose the multi-head (cross-)attention backend once at import time. The
# first enabled option wins; sub-quadratic attention is the unconditional
# fallback. Callers use the module-level name `attention_function`.
if memory_management.xformers_enabled():
    print("Using xformers cross attention")
    attention_function = attention_xformers
elif memory_management.pytorch_attention_enabled():
    print("Using pytorch cross attention")
    attention_function = attention_pytorch
elif args.attention_split:
    print("Using split optimization for cross attention")
    attention_function = attention_split
else:
    print("Using sub quadratic optimization for cross attention")
    attention_function = attention_sub_quad

# Choose the single-head spatial (VAE) attention backend the same way; the
# sliced implementation is the fallback.
if memory_management.xformers_enabled_vae():
    print("Using xformers attention for VAE")
    attention_function_single_head_spatial = xformers_attention_single_head_spatial
elif memory_management.pytorch_attention_enabled():
    print("Using pytorch attention for VAE")
    attention_function_single_head_spatial = pytorch_attention_single_head_spatial
else:
    print("Using split attention for VAE")
    attention_function_single_head_spatial = normal_attention_single_head_spatial
class AttentionProcessorForge:
    """Attention processor that routes the core attention through ``attention_function``.

    ``attn`` is presumably a diffusers-style ``Attention`` module — this class
    reads its ``spatial_norm``, ``group_norm``, ``norm_cross``, ``heads``,
    ``to_q``/``to_k``/``to_v``/``to_out``, ``prepare_attention_mask``,
    ``norm_encoder_hidden_states``, ``residual_connection`` and
    ``rescale_output_factor`` attributes.
    """

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, *args, **kwargs):
        """Run attention for ``attn`` on ``hidden_states``.

        Args:
            attn: The owning attention module (see class docstring).
            hidden_states: ``(batch, tokens, channels)`` or a 4-D
                ``(batch, channel, height, width)`` map, flattened to tokens
                internally and restored afterwards.
            encoder_hidden_states: Keys/values source for cross-attention;
                when None, ``hidden_states`` is used (self-attention).
                Defaults to None, so callers may omit it.
            attention_mask: Optional mask, prepared by
                ``attn.prepare_attention_mask`` and viewed as
                ``(batch, attn.heads, -1, key_len)``.
            temb: Passed to ``attn.spatial_norm`` when that is not None.
            *args, **kwargs: Accepted and ignored. (The local ``args`` shadows
                the module-level ``args`` import inside this method.)

        Returns:
            Tensor with the same layout as the input ``hidden_states``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        hidden_states = attention_function(query, key, value, heads=attn.heads, mask=attention_mask)

        # Output projection followed by the second to_out stage.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states