# original source:
#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
# license:
#   MIT
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# implementation of:
#   "Self-attention Does Not Need O(n2) Memory":
#   https://arxiv.org/abs/2112.05682v2

from functools import partial
import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math

try:
    from typing import Optional, NamedTuple, List, Protocol
except ImportError:
    from typing import Optional, NamedTuple, List
    from typing_extensions import Protocol

from torch import Tensor
from typing import List

from comfy import model_management


def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    """Return the sub-tensor ``x[starts[i] : starts[i] + sizes[i]]`` along each leading dim.

    Args:
        x: tensor to slice.
        starts: per-dimension start offsets (any sequence of ints works).
        sizes: per-dimension extents, paired element-wise with `starts`.

    Returns:
        The result of basic slice indexing on `x`. Slices running past the
        end of a dimension are clipped by normal slicing rules, and
        dimensions beyond ``len(starts)`` are left untouched.
    """
    # The index must be a tuple: a list of slices is a non-tuple sequence,
    # which PyTorch deprecates for multidimensional indexing.
    slicing = tuple(slice(start, start + size) for start, size in zip(starts, sizes))
    return x[slicing]


class AttnChunk(NamedTuple):
    """Partial attention result for one key/value chunk."""

    # exponentiated (max-shifted) attention weights multiplied by the values
    exp_values: Tensor
    # sum of the exponentiated (max-shifted) attention weights
    exp_weights_sum: Tensor
    # maximum attention score that was subtracted before exponentiation
    max_score: Tensor


class SummarizeChunk(Protocol):
    """Callable that summarises one key/value chunk into an `AttnChunk`."""

    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> AttnChunk: ...


class ComputeQueryChunkAttn(Protocol):
    """Callable that computes the attention output for one query chunk."""

    @staticmethod
    def __call__(
        query: Tensor,
        key_t: Tensor,
        value: Tensor,
    ) -> Tensor: ...
def _attention_scores(
    query: Tensor,
    key_t: Tensor,
    scale: float,
    upcast_attention: bool,
) -> Tensor:
    """Return ``scale * (query @ key_t)``, optionally computed in float32.

    Args:
        query: left operand of the batched matmul.
        key_t: right operand of the batched matmul (already transposed keys).
        scale: multiplier applied to the product (baddbmm's ``alpha``).
        upcast_attention: when True, both operands are cast with ``.float()``
            and the product is computed with CUDA autocast disabled.

    Returns:
        The scaled scores; their dtype follows the (possibly upcast) query.
    """
    def scaled_matmul(q: Tensor, k_t: Tensor) -> Tensor:
        # beta=0 makes baddbmm ignore its `input` argument, so an
        # uninitialised 1x1x1 placeholder avoids allocating a real bias.
        return torch.baddbmm(
            torch.empty(1, 1, 1, device=q.device, dtype=q.dtype),
            q,
            k_t,
            alpha=scale,
            beta=0,
        )

    if upcast_attention:
        with torch.autocast(enabled=False, device_type = 'cuda'):
            return scaled_matmul(query.float(), key_t.float())
    return scaled_matmul(query, key_t)


def _summarize_chunk(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
) -> AttnChunk:
    """Summarise attention of `query` against a single key/value chunk.

    Args:
        query: query chunk.
        key_t: transposed key chunk.
        value: value chunk.
        scale: multiplier applied to the query-key product.
        upcast_attention: compute the scores in float32 with autocast disabled.

    Returns:
        AttnChunk of (exp-weights @ value, sum of exp-weights over the last
        dim, per-row max score). The max is subtracted before exponentiating
        and returned so the caller can rescale chunks onto a common maximum.
    """
    attn_weights = _attention_scores(query, key_t, scale, upcast_attention)
    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
    max_score = max_score.detach()
    # Exponentiate in place to avoid holding a second scores-sized tensor.
    torch.exp(attn_weights - max_score, out=attn_weights)
    exp_weights = attn_weights.to(value.dtype)
    exp_values = torch.bmm(exp_weights, value)
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    summarize_chunk: SummarizeChunk,
    kv_chunk_size: int,
) -> Tensor:
    """Compute attention for one query chunk by scanning key/value chunks.

    Args:
        query: query chunk.
        key_t: transposed keys; its three dims are unpacked as
            (batch_x_heads, k_channels_per_head, k_tokens).
        value: values; its last dim is taken as v_channels_per_head.
        summarize_chunk: callable producing an `AttnChunk` per key/value chunk.
        kv_chunk_size: number of key/value tokens processed per chunk.

    Returns:
        Chunk values rescaled to a shared maximum and summed, divided by the
        equally rescaled and summed weights.
    """
    batch_x_heads, k_channels_per_head, k_tokens = key_t.shape
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = dynamic_slice(
            key_t,
            (0, 0, chunk_idx),
            (batch_x_heads, k_channels_per_head, kv_chunk_size)
        )
        value_chunk = dynamic_slice(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
        )
        return summarize_chunk(query, key_chunk, value_chunk)

    # Plain ints as slice offsets; iterating a tensor would hand 0-d tensors
    # to the slicing code for no benefit.
    chunks: List[AttnChunk] = [
        chunk_scanner(chunk_start) for chunk_start in range(0, k_tokens, kv_chunk_size)
    ]
    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
    chunk_values, chunk_weights, chunk_max = acc_chunk

    # Each chunk was exponentiated relative to its own max; rescale them all
    # to the global max before combining.
    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
    max_diffs = torch.exp(chunk_max - global_max)
    chunk_values *= torch.unsqueeze(max_diffs, -1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(dim=0)
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    scale: float,
    upcast_attention: bool,
) -> Tensor:
    """Compute softmax attention against all keys/values at once.

    Args:
        query: queries.
        key_t: transposed keys.
        value: values.
        scale: multiplier applied to the query-key product.
        upcast_attention: compute the scores in float32 with autocast disabled.

    Returns:
        ``softmax(scale * query @ key_t, dim=-1) @ value``, with the
        probabilities cast to `value.dtype` before the final matmul.
    """
    attn_scores = _attention_scores(query, key_t, scale, upcast_attention)

    try:
        attn_probs = attn_scores.softmax(dim=-1)
        del attn_scores
    except model_management.OOM_EXCEPTION:
        print("ran out of memory while running softmax in  _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
        # Same softmax, done step by step in place so no second scores-sized
        # tensor has to be allocated.
        attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
        torch.exp(attn_scores, out=attn_scores)
        summed = torch.sum(attn_scores, dim=-1, keepdim=True)
        attn_scores /= summed
        attn_probs = attn_scores

    hidden_states_slice = torch.bmm(attn_probs.to(value.dtype), value)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    """An `AttnChunk` paired with the index of the chunk it came from."""

    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key_t: Tensor,
    value: Tensor,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
    upcast_attention: bool = False,
) -> Tensor:
    """Computes efficient dot-product attention given query, transposed key, and value.
      This is efficient version of attention presented in
      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
      Args:
        query: queries for calculating attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        key_t: keys for calculating attention with shape of
          `[batch * num_heads, channels_per_head, tokens]`.
        value: values to be used in attention with shape of
          `[batch * num_heads, tokens, channels_per_head]`.
        query_chunk_size: int: query chunks size
        kv_chunk_size: Optional[int]: key/value chunks size. if None (or 0): defaults to
          sqrt(key_tokens). the result is capped at key_tokens.
        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. when not None, the
          chunk size (whether passed explicitly or derived from sqrt(key_tokens)) is raised
          to at least this value, to ensure our chunk sizes don't get too small
          (smaller chunks = more chunks = less concurrent work done).
        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
        upcast_attention: bool: whether to compute the attention scores in float32 with
          autocast disabled.
      Returns:
        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
      """
    batch_x_heads, q_tokens, q_channels_per_head = query.shape
    _, _, k_tokens = key_t.shape
    scale = q_channels_per_head ** -0.5

    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
        )

    compute_query_chunk_attn: ComputeQueryChunkAttn
    if k_tokens <= kv_chunk_size:
        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
        compute_query_chunk_attn = partial(
            _get_attention_scores_no_kv_chunking,
            scale=scale,
            upcast_attention=upcast_attention
        )
    else:
        summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention)
        if use_checkpoint:
            summarize_chunk = partial(checkpoint, summarize_chunk)
        compute_query_chunk_attn = partial(
            _query_chunk_attention,
            kv_chunk_size=kv_chunk_size,
            summarize_chunk=summarize_chunk,
        )

    if q_tokens <= query_chunk_size:
        # fast-path for when there's just 1 query chunk
        return compute_query_chunk_attn(
            query=query,
            key_t=key_t,
            value=value,
        )

    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
            query=get_query_chunk(chunk_start),
            key_t=key_t,
            value=value,
        ) for chunk_start in range(0, q_tokens, query_chunk_size)
    ], dim=1)
    return res