import math
from typing import Any

from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention

# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135

# Lower bound for per-block softmax denominators, so a key block whose entries
# are all masked out cannot produce a division by zero.
EPSILON = 1e-6


class FlashAttentionFunction(torch.autograd.function.Function):
    """Bucketed (FlashAttention-style) scaled dot-product attention.

    Queries are processed in buckets of ``q_bucket_size`` rows and keys/values
    in buckets of ``k_bucket_size`` rows, so the full attention matrix is never
    materialized. The forward pass keeps running row maxima and row sums (an
    online softmax); the backward pass recomputes attention weights block by
    block from those saved statistics.
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """Algorithm 2 in the paper.

        Args:
            q: Queries; the second-to-last dim is the query sequence and the
                last dim the head size. The 4-D mask layout below implies
                (batch, heads, n_q, d) — NOTE(review): confirm with callers.
            k: Keys, same layout as ``q`` with key sequence length n_k.
            v: Values, same layout as ``k``.
            mask: ``None`` or a boolean key mask of shape (batch, n_k);
                ``True`` keeps a key, ``False`` excludes it from attention.
            causal: If true, apply a causal mask (offset so it aligns with the
                end of the key sequence when n_k > n_q).
            q_bucket_size: Rows per query bucket.
            k_bucket_size: Rows per key/value bucket.

        Returns:
            Attention output with the same shape and dtype as ``q``.
        """
        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full(
            (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
        )

        scale = q.shape[-1] ** -0.5

        # The mask selects KEYS, so it must be cut into key buckets and paired
        # with the matching key/value chunks (not with the query rows).
        if mask is None:
            col_masks = (None,) * math.ceil(k.shape[-2] / k_bucket_size)
        else:
            mask = rearrange(mask, "b n -> b 1 1 n")
            col_masks = mask.split(k_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                col_masks,
            )

            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if col_mask is not None:
                    attn_weights.masked_fill_(~col_mask, max_neg_value)

                # Only build a causal mask when this block can contain a
                # key position later than some query position.
                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                block_row_maxes = attn_weights.amax(dim=-1, keepdim=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                # A fully masked row has max == max_neg_value, giving exp(0) == 1
                # everywhere; zero those entries so they contribute nothing.
                if col_mask is not None:
                    exp_weights.masked_fill_(~col_mask, 0.0)

                block_row_sums = exp_weights.sum(dim=-1, keepdim=True).clamp(
                    min=EPSILON
                )

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = torch.einsum(
                    "... i j, ... j d -> ... i d", exp_weights, vc
                )

                # Rescale the previous accumulators and this block's values to
                # the new common row maximum (online softmax update).
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = (
                    exp_row_max_diff * row_sums
                    + exp_block_row_max_diff * block_row_sums
                )

                oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
                    (exp_block_row_max_diff / new_row_sums) * exp_values
                )

                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, col_masks, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """Algorithm 4 in the paper.

        Args:
            do: Gradient of the loss w.r.t. the forward output (shape of ``q``).

        Returns:
            Gradients for ``(q, k, v)`` followed by ``None`` for the four
            non-differentiable forward arguments.
        """
        causal, scale, col_masks, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, all_row_sums, all_row_maxes = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2),
        )

        for ind, (qc, oc, doc, lc, mc, dqc) in enumerate(row_splits):
            q_start_index = ind * q_bucket_size - qk_len_diff

            # D_i = rowsum(dO_i * O_i) depends only on the query bucket, so it
            # is computed once here instead of once per key bucket.
            D = (doc * oc).sum(dim=-1, keepdim=True)

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
                col_masks,
            )

            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                k_start_index = k_ind * k_bucket_size

                attn_weights = (
                    torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
                )

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones(
                        (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
                    ).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # Recompute the softmax probabilities from the saved row
                # maxima (mc) and row sums (lc).
                exp_attn_weights = torch.exp(attn_weights - mc)

                if col_mask is not None:
                    exp_attn_weights.masked_fill_(~col_mask, 0.0)

                p = exp_attn_weights / lc

                dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
                dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)

                ds = p * scale * (dp - D)

                dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
                dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)

                dqc.add_(dq_chunk)
                dkc.add_(dk_chunk)
                dvc.add_(dv_chunk)

        return dq, dk, dv, None, None, None, None


class FlashAttnProcessor:
    """diffusers attention processor that computes attention with
    :class:`FlashAttentionFunction`.

    The query/key/value and output projections come from the ``attn`` module;
    only the attention computation itself is replaced.
    """

    def __init__(self, q_bucket_size: int = 512, k_bucket_size: int = 1024):
        """
        Args:
            q_bucket_size: Query rows processed per bucket.
            k_bucket_size: Key/value rows processed per bucket.
        """
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ) -> Any:
        """Run attention for ``attn`` on ``hidden_states``.

        Args:
            attn: The attention module supplying ``heads``, ``to_q``, ``to_k``,
                ``to_v`` and ``to_out``.
            hidden_states: Query-side input, laid out as (batch, seq, channels)
                per the ``"b n (h d)"`` rearrange below.
            encoder_hidden_states: Key/value-side input for cross-attention;
                falls back to ``hidden_states`` (self-attention) when ``None``.
            attention_mask: Forwarded unchanged to
                :class:`FlashAttentionFunction`, which expects ``None`` or a
                boolean (batch, key_len) mask.

        Returns:
            The projected attention output, shaped like ``hidden_states`` after
            ``to_out``.
        """
        h = attn.heads

        q = attn.to_q(hidden_states)

        encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else hidden_states
        )
        encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)

        # Optional hypernetwork attached to the module: it supplies separate
        # inputs for the key and value projections.
        if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
            context_k, context_v = attn.hypernetwork.forward(
                hidden_states, encoder_hidden_states
            )
            context_k = context_k.to(hidden_states.dtype)
            context_v = context_v.to(hidden_states.dtype)
        else:
            context_k = encoder_hidden_states
            context_v = encoder_hidden_states

        k = attn.to_k(context_k)
        v = attn.to_v(context_v)
        del encoder_hidden_states, hidden_states

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = FlashAttentionFunction.apply(
            q, k, v, attention_mask, False, self.q_bucket_size, self.k_bucket_size
        )

        out = rearrange(out, "b h n d -> b n (h d)")

        # to_out[0] is applied first, then to_out[1].
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out