# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ..initializer import normal_, zeros_


class CrossAttention(nn.Layer):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the `encoder_hidden_states`. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to `False`):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        self.scale = dim_head**-0.5
        self.num_heads = heads
        self.head_dim = inner_dim // heads

        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads
        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, epsilon=1e-5)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias_attr=bias)

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        self.to_out = nn.LayerList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller than or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        # The `CrossAttention` class can call different attention processors / attention functions.
        # Here we simply pass along all tensors to the selected processor class.
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        # [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads * head_dim]
        # (a 0 in Paddle's `reshape` keeps the corresponding input dimension)
        tensor = tensor.transpose([0, 2, 1, 3])
        tensor = tensor.reshape([0, 0, tensor.shape[2] * tensor.shape[3]])
        return tensor

    def head_to_batch_dim(self, tensor):
        # [batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim]
        tensor = tensor.reshape([0, 0, self.num_heads, self.head_dim])
        tensor = tensor.transpose([0, 2, 1, 3])
        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        if self.upcast_attention:
            query = query.cast("float32")
            key = key.cast("float32")

        # scaled dot-product attention: softmax(Q @ K^T / sqrt(head_dim) + mask)
        attention_scores = paddle.matmul(query, key, transpose_y=True) * self.scale

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.cast("float32")

        attention_probs = F.softmax(attention_scores, axis=-1)
        if self.upcast_softmax:
            attention_probs = attention_probs.cast(query.dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length):
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            # pad the mask out to `target_length` along the last (key) axis
            attention_mask = F.pad(
                attention_mask, (0, target_length - attention_mask.shape[-1]), value=0.0, data_format="NCL"
            )
        # one mask per attention head: [batch, ...] -> [batch * heads, ...]
        attention_mask = attention_mask.repeat_interleave(self.num_heads, axis=0)
        return attention_mask
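
# A minimal usage sketch (the shapes are illustrative, not taken from this module):
#
#     attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
#     hidden_states = paddle.randn([2, 64, 320])           # [batch, query_seq_len, query_dim]
#     encoder_hidden_states = paddle.randn([2, 77, 768])   # [batch, kv_seq_len, cross_attention_dim]
#     out = attn(hidden_states, encoder_hidden_states)     # [2, 64, 320]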

class CrossAttnProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
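
# Calling the layer and calling the processor directly are equivalent; `forward` simply
# dispatches to `attn.processor` (a sketch, reusing the tensors from the example above):
#
#     out = CrossAttnProcessor()(attn, hidden_states, encoder_hidden_states)
#     # == attn(hidden_states, encoder_hidden_states)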

class LoRALinearLayer(nn.Layer):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias_attr=False)
        self.up = nn.Linear(rank, out_features, bias_attr=False)
        self.scale = 1.0

        # `up` starts at zero, so the LoRA delta is initially a no-op
        normal_(self.down.weight, std=1 / rank)
        zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.cast(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.cast(orig_dtype)
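
# A minimal sketch of how the low-rank path behaves (`base_linear` and `scale` are
# placeholders for a frozen projection layer and the LoRA weight, not names from this file):
#
#     lora = LoRALinearLayer(320, 320, rank=4)
#     x = paddle.randn([2, 64, 320])
#     delta = lora(x)                     # all zeros right after init, since `up` is zero-initialized
#     y = base_linear(x) + scale * delta  # how the processors below combine it with the base layer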

class LoRACrossAttnProcessor(nn.Layer):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
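
# A sketch of attaching a LoRA processor to an existing layer (dims must match the layer's
# inner and cross-attention dims; values below follow the earlier illustrative example):
#
#     lora_proc = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
#     attn.set_processor(lora_proc)
#     out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, scale=0.5)
#
# The extra `scale` keyword is forwarded through `**cross_attention_kwargs` in `forward`.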

class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        batch_size, sequence_length, _ = hidden_states.shape
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
        attention_mask = (
            attention_mask.reshape([batch_size, attn.num_heads, -1, attention_mask.shape[-1]])
            if attention_mask is not None
            else None
        )

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        # prepend the projected encoder states along the key/value sequence axis
        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.matmul(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual
        return hidden_states
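
# This processor expects channel-first inputs (a sketch of the assumed layouts; the exact
# shapes depend on the calling model): `hidden_states` of shape [batch, channels, ...spatial],
# flattened internally to [batch, seq_len, channels], and `encoder_hidden_states` of shape
# [batch, added_kv_proj_dim, seq_len], transposed internally to [batch, seq_len, added_kv_proj_dim].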

class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # fold heads into the batch axis so slices can be taken along a single dimension
        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        # compute attention one slice at a time to bound peak memory
        # (assumes `slice_size` evenly divides batch * heads)
        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape back to [batch, num_heads, seq_len, head_dim]
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
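
# Sliced attention is normally enabled through the layer rather than constructed directly
# (a sketch; `slice_size` must not exceed the number of heads):
#
#     attn.set_attention_slice(2)   # swaps the processor for SlicedAttnProcessor(2)
#     out = attn(hidden_states, encoder_hidden_states)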

class SlicedAttnAddedKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.reshape([hidden_states.shape[0], hidden_states.shape[1], -1]).transpose(
            [0, 2, 1]
        )
        encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        hidden_states = attn.group_norm(hidden_states.transpose([0, 2, 1])).transpose([0, 2, 1])

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        key = paddle.concat([encoder_hidden_states_key_proj, key], axis=2)
        value = paddle.concat([encoder_hidden_states_value_proj, value], axis=2)

        query = query.flatten(0, 1)
        key = key.flatten(0, 1)
        value = value.flatten(0, 1)

        batch_size_attention = query.shape[0]
        hidden_states = paddle.zeros((batch_size_attention, sequence_length, attn.head_dim), dtype=query.dtype)

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = paddle.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape back to [batch, num_heads, seq_len, head_dim]
        hidden_states = hidden_states.reshape([-1, attn.num_heads, sequence_length, attn.head_dim])
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose([0, 2, 1]).reshape(residual.shape)
        hidden_states = hidden_states + residual
        return hidden_states
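
# For layers built with `added_kv_proj_dim`, `set_attention_slice` selects this sliced
# variant automatically (a sketch with illustrative dims):
#
#     attn_added_kv = CrossAttention(query_dim=320, heads=8, dim_head=40,
#                                    added_kv_proj_dim=768, norm_num_groups=32)
#     attn_added_kv.set_attention_slice(2)  # -> SlicedAttnAddedKVProcessor(2)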

AttnProcessor = Union[
    CrossAttnProcessor,
    LoRACrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
]