# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..initializer import normal_, zeros_


class CrossAttention(nn.Layer):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5

        self.num_heads = heads
        self.head_dim = inner_dim // heads

        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        self.to_out = nn.LayerList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `CrossAttention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, num_heads * head_dim]
        tensor = tensor.transpose([0, 2, 1, 3])
        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
        return tensor

    def head_to_batch_dim(self, tensor):
        # [batch_size, seq_len, num_heads * head_dim] -> [batch_size, num_heads, seq_len, head_dim]
        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
        tensor = tensor.transpose([0, 2, 1, 3])
        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        if self.upcast_attention:
            query = query.cast("float32")
            key = key.cast("float32")

        attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.cast("float32")

        attention_probs = F.softmax(attention_scores, axis=-1)
        if self.upcast_softmax:
            attention_probs = attention_probs.cast(query.dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length):
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0, data_format="NCL")
            attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
        return attention_mask


class CrossAttnProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class LoRALinearLayer(nn.Layer):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias_attr=False)
        self.up = nn.Linear(rank, out_features, bias_attr=False)
        self.scale = 1.0

        normal_(self.down.weight, std=1 / rank)
        zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.cast(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.cast(orig_dtype)
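
# A minimal sketch of how `LoRALinearLayer` behaves (not part of the original
# module; the feature sizes below are illustrative assumptions). It is a
# low-rank residual path that projects down to `rank` and back up, and its
# output starts at exactly zero because `up` is zero-initialized:
#
#     lora = LoRALinearLayer(in_features=320, out_features=320, rank=4)
#     x = paddle.randn([2, 77, 320])
#     delta = lora(x)  # all zeros right after initialization
#
# The LoRA processor below adds `scale * delta` on top of the frozen base
# projections (to_q, to_k, to_v, to_out).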


class LoRACrossAttnProcessor(nn.Layer):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        batch_size, sequence_length, _ = hidden_states.shape
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape back to [bs, num_heads, seqlen, head_dim]
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        # reshape hidden_states
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class SlicedAttnAddedKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape back to [bs, num_heads, seqlen, head_dim]
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        # reshape hidden_states
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


AttnProcessor = Union[
    CrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
]
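
# ------------------------------------------------------------------------------
# Usage sketch (not part of the original module; the shapes and sizes below are
# illustrative assumptions). A `CrossAttention` layer dispatches its forward
# pass to whichever processor is currently set:
#
#     attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
#     hidden_states = paddle.randn([2, 4096, 320])        # e.g. image tokens
#     encoder_hidden_states = paddle.randn([2, 77, 768])  # e.g. text tokens
#     out = attn(hidden_states, encoder_hidden_states)    # -> [2, 4096, 320]
#
#     # Memory-saving sliced attention (slices over batch * heads):
#     attn.set_attention_slice(4)                         # installs SlicedAttnProcessor
#
#     # Swapping in a LoRA processor for fine-tuning:
#     attn.set_processor(LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4))
#     out = attn(hidden_states, encoder_hidden_states, scale=0.5)
# ------------------------------------------------------------------------------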