import math
import torch
import einops

from backend.args import args
from backend import memory_management
from backend.misc.sub_quadratic_attention import efficient_dot_product_attention

BROKEN_XFORMERS = False

if memory_management.xformers_enabled():
    import xformers
    import xformers.ops

    # xformers 0.0.2x builds (except 0.0.20) cannot handle batch * heads > 65535;
    # see the fallback in attention_xformers below.
    try:
        x_vers = xformers.__version__
        BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
    except Exception:
        pass

FORCE_UPCAST_ATTENTION_DTYPE = memory_management.force_upcast_attention_dtype()


def get_attn_precision(attn_precision=torch.float32):
    if args.disable_attention_upcast:
        return None
    if FORCE_UPCAST_ATTENTION_DTYPE is not None:
        return FORCE_UPCAST_ATTENTION_DTYPE
    return attn_precision
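
# Note (not in the original source): precision resolution order is
#   1. args.disable_attention_upcast -> None (no upcast at all),
#   2. FORCE_UPCAST_ATTENTION_DTYPE if memory_management forces a dtype,
#   3. otherwise the caller's requested attn_precision (default torch.float32).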


def exists(val):
    return val is not None


def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        # q/k/v already arrive as (b, heads, seq, dim_head)
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5
    h = heads

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    # Optionally upcast the similarity matrix to float32 for numerical stability.
    if attn_precision == torch.float32:
        sim = torch.einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
    else:
        sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale

    del q, k

    if exists(mask):
        if mask.dtype == torch.bool:
            mask = einops.rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = einops.repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)
        else:
            if len(mask.shape) == 2:
                bs = 1
            else:
                bs = mask.shape[0]
            mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])
            sim.add_(mask)

    sim = sim.softmax(dim=-1)

    out = torch.einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out
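
# Illustrative usage sketch (not part of the original module): with heads=8 and
# dim_head=64, packed q/k/v tensors of shape (batch, seq, heads * dim_head) map to
# an output of the same shape.
#
#   q = torch.randn(2, 4096, 8 * 64)
#   out = attention_basic(q, q, q, heads=8)   # out.shape == (2, 4096, 512)
#
# With skip_reshape=True the inputs are expected pre-split as (batch, heads, seq, dim_head).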


def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = query.shape
    else:
        b, _, dim_head = query.shape
        dim_head //= heads

    scale = dim_head ** -0.5

    if skip_reshape:
        query = query.reshape(b * heads, -1, dim_head)
        value = value.reshape(b * heads, -1, dim_head)
        key = key.reshape(b * heads, -1, dim_head).movedim(1, 2)
    else:
        query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
        key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

    dtype = query.dtype
    upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
    if upcast_attention:
        bytes_per_token = torch.finfo(torch.float32).bits // 8
    else:
        bytes_per_token = torch.finfo(query.dtype).bits // 8

    batch_x_heads, q_tokens, _ = query.shape
    _, _, k_tokens = key.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    mem_free_total, mem_free_torch = memory_management.get_free_memory(query.device, True)

    kv_chunk_size_min = None
    kv_chunk_size = None
    query_chunk_size = None

    # Pick the largest candidate query chunk size whose estimated memory footprint
    # still allows keeping the key/value dimension unchunked.
    for x in [4096, 2048, 1024, 512, 256]:
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens
            query_chunk_size = x
            break

    if query_chunk_size is None:
        query_chunk_size = 512

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    hidden_states = efficient_dot_product_attention(
        query,
        key,
        value,
        query_chunk_size=query_chunk_size,
        kv_chunk_size=kv_chunk_size,
        kv_chunk_size_min=kv_chunk_size_min,
        use_checkpoint=False,
        upcast_attention=upcast_attention,
        mask=mask,
    )

    hidden_states = hidden_states.to(dtype)
    hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
    return hidden_states
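
# Worked example of the chunk-size heuristic above (illustrative numbers, not from
# the original source): with batch_x_heads = 16, bytes_per_token = 2 (fp16),
# k_tokens = 4096 and 8 GiB free, the first candidate x = 4096 already gives
# 8 * 1024**3 / (16 * 2 * 4096 * 4.0) = 16384 >= 4096, so the key/value dimension
# stays unchunked and query_chunk_size = 4096.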


def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    attn_precision = get_attn_precision(attn_precision)

    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    scale = dim_head ** -0.5
    h = heads

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, -1, heads, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * heads, -1, dim_head)
            .contiguous(),
            (q, k, v),
        )

    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

    mem_free_total = memory_management.get_free_memory(q.device)

    if attn_precision == torch.float32:
        element_size = 4
        upcast = True
    else:
        element_size = q.element_size()
        upcast = False

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
    modifier = 3
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

    if steps > 64:
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

    if mask is not None:
        if len(mask.shape) == 2:
            bs = 1
        else:
            bs = mask.shape[0]
        mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1])

    # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)

    first_op_done = False
    cleared_cache = False
    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                if upcast:
                    with torch.autocast(enabled=False, device_type='cuda'):
                        s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
                else:
                    s1 = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

                if mask is not None:
                    if len(mask.shape) == 2:
                        s1 += mask[i:end]
                    else:
                        s1 += mask[:, i:end]

                s2 = s1.softmax(dim=-1).to(v.dtype)
                del s1
                first_op_done = True

                r1[:, i:end] = torch.einsum('b i j, b j d -> b i d', s2, v)
                del s2
            break
        except memory_management.OOM_EXCEPTION as e:
            if not first_op_done:
                memory_management.soft_empty_cache(True)
                if not cleared_cache:
                    cleared_cache = True
                    print("out of memory error, emptying cache and trying again")
                    continue
                steps *= 2
                if steps > 64:
                    raise e
                print("out of memory error, increasing steps and trying again {}".format(steps))
            else:
                raise e

    del q, k, v

    r1 = (
        r1.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return r1
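
# Note on the OOM retry loop above (illustrative numbers, not from the original
# source): steps doubles until the per-slice similarity matrix fits. For example,
# mem_required = 12 GB against mem_free_total = 5 GB gives
# steps = 2 ** ceil(log2(12 / 5)) = 4, so each pass computes attention for
# q.shape[1] // 4 query rows at a time; beyond 64 steps the function gives up.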


def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads

    # Broken xformers builds cannot handle batch * heads > 65535; fall back to
    # the PyTorch scaled-dot-product path in that case.
    if BROKEN_XFORMERS and b * heads > 65535:
        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)

    if skip_reshape:
        q, k, v = map(
            lambda t: t.reshape(b * heads, -1, dim_head),
            (q, k, v),
        )
    else:
        q, k, v = map(
            lambda t: t.reshape(b, -1, heads, dim_head),
            (q, k, v),
        )

    if mask is not None:
        # Allocate the mask with its last dimension padded to a multiple of 8 for
        # xformers, then slice back to the original size.
        pad = 8 - q.shape[1] % 8
        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        mask = mask_out[:, :, :mask.shape[-1]]

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    if skip_reshape:
        out = (
            out.unsqueeze(0)
            .reshape(b, heads, -1, dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, -1, heads * dim_head)
        )
    else:
        out = (
            out.reshape(b, -1, heads * dim_head)
        )

    return out


def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        # q/k/v already arrive as (b, heads, seq, dim_head)
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    out = (
        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    )
    return out
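
# Illustrative sketch (not part of the original module): this path is a thin wrapper
# around torch.nn.functional.scaled_dot_product_attention and keeps the same packed
# (batch, seq, heads * dim_head) interface as the other backends.
#
#   q = torch.randn(1, 77, 8 * 64)
#   out = attention_pytorch(q, q, q, heads=8)   # out.shape == (1, 77, 512)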


def slice_attention_single_head_spatial(q, k, v):
    r1 = torch.zeros_like(k, device=q.device)
    scale = (int(q.shape[-1]) ** (-0.5))

    mem_free_total = memory_management.get_free_memory(q.device)

    gb = 1024 ** 3
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    while True:
        try:
            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            for i in range(0, q.shape[1], slice_size):
                end = i + slice_size
                s1 = torch.bmm(q[:, i:end], k) * scale

                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
                del s1

                r1[:, :, i:end] = torch.bmm(v, s2)
                del s2
            break
        except memory_management.OOM_EXCEPTION as e:
            memory_management.soft_empty_cache(True)
            steps *= 2
            if steps > 128:
                raise e
            print("out of memory error, increasing steps and trying again {}".format(steps))
    return r1


def normal_attention_single_head_spatial(q, k, v):
    # compute attention
    b, c, h, w = q.shape

    q = q.reshape(b, c, h * w)
    q = q.permute(0, 2, 1)  # b, hw, c
    k = k.reshape(b, c, h * w)  # b, c, hw
    v = v.reshape(b, c, h * w)

    r1 = slice_attention_single_head_spatial(q, k, v)
    h_ = r1.reshape(b, c, h, w)
    del r1
    return h_


def xformers_attention_single_head_spatial(q, k, v):
    # compute attention
    B, C, H, W = q.shape

    q, k, v = map(
        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
        (q, k, v),
    )

    try:
        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
        out = out.transpose(1, 2).reshape(B, C, H, W)
    except NotImplementedError:
        out = slice_attention_single_head_spatial(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2),
                                                  v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out


def pytorch_attention_single_head_spatial(q, k, v):
    # compute attention
    B, C, H, W = q.shape

    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )

    try:
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        out = out.transpose(2, 3).reshape(B, C, H, W)
    except memory_management.OOM_EXCEPTION:
        print("scaled_dot_product_attention OOMed: switched to slice attention")
        out = slice_attention_single_head_spatial(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2),
                                                  v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
    return out


if memory_management.xformers_enabled():
    print("Using xformers cross attention")
    attention_function = attention_xformers
elif memory_management.pytorch_attention_enabled():
    print("Using pytorch cross attention")
    attention_function = attention_pytorch
elif args.attention_split:
    print("Using split optimization for cross attention")
    attention_function = attention_split
else:
    print("Using sub quadratic optimization for cross attention")
    attention_function = attention_sub_quad

if memory_management.xformers_enabled_vae():
    print("Using xformers attention for VAE")
    attention_function_single_head_spatial = xformers_attention_single_head_spatial
elif memory_management.pytorch_attention_enabled():
    print("Using pytorch attention for VAE")
    attention_function_single_head_spatial = pytorch_attention_single_head_spatial
else:
    print("Using split attention for VAE")
    attention_function_single_head_spatial = normal_attention_single_head_spatial
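
# Summary note (not part of the original source): attention_function is the backend
# selected for cross attention and expects packed (batch, seq, heads * dim_head)
# tensors, while attention_function_single_head_spatial is the backend selected for
# VAE attention and operates on (batch, channels, height, width) feature maps.
# Example call, assuming a packed layout:
#
#   out = attention_function(q, k, v, heads=8, mask=None)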


class AttentionProcessorForge:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None, temb=None, *args, **kwargs):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten spatial inputs (b, c, h, w) -> (b, h*w, c) for attention.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        hidden_states = attention_function(query, key, value, heads=attn.heads, mask=attention_mask)

        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
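
# Illustrative wiring sketch (an assumption, not shown in this file): the processor
# follows the diffusers attention-processor calling convention, so it would typically
# be attached to a diffusers-style Attention module, e.g.:
#
#   # attn: diffusers.models.attention_processor.Attention
#   attn.set_processor(AttentionProcessorForge())
#
# after which the module's forward pass routes q/k/v through attention_function above.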