from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input

import torch


def flash_attention2_forward_with_window_size(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    dropout=0.0,
    softmax_scale=None,
    window_size=[-1, -1],
    return_attn_probs=False,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
    first unpads the input, then computes the attention scores and re-pads the final attention scores.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to the Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to the Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to the Flash Attention API
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        query_length (`int`):
            The length of the query sequence.
        dropout (`float`, *optional*):
            Attention dropout probability.
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        window_size (`[int, int]`, *optional*):
            The left and right sliding-window sizes passed to Flash Attention. Defaults to [-1, -1], which
            means no window is used.
        return_attn_probs (`bool`, *optional*):
            Whether to also return the attention softmax logsumexp and probabilities. Defaults to False.
    """
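    # Decide whether the kernel should apply a causal mask. For implementations flagged as using a
    # top-left-aligned causal mask, causality is only applied when more than one query token is present
    # (i.e. not during single-token decoding).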
    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and query_length != 1

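    # If the batch contains padding tokens, unpad the inputs, run the variable-length flash attention
    # kernel, and re-pad the output afterwards; otherwise call the dense kernel directly.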
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output_unpad, softmax_lse, S_dmask = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
    else:
        attn_output, softmax_lse, S_dmask = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            return_attn_probs=True,
        )

    if return_attn_probs:
        return attn_output, softmax_lse, S_dmask
    else:
        return attn_output


def self_extend_flash_forward(
    model_self,
    query_position,
    group_size_2,
    neighbor_query_states,
    neighbor_key_states,
    group_query_states,
    group_key_states,
    value_states,
    attention_mask,
    bsz,
    q_len,
    kv_seq_len,
    attn_dropout,
):
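    """
    Self-Extend forward pass built on Flash Attention: runs a "neighbor" attention pass restricted to a causal
    sliding window of `group_size_2` tokens and a "grouped" attention pass over the remaining prefix of the
    key/value sequence, then merges the two outputs using their softmax logsumexp values so the two partial
    softmaxes are renormalized into a single attention distribution.
    """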
    if query_position.max() >= group_size_2:
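        # Neighbor attention: ordinary positions, limited to a causal sliding window of `group_size_2` tokens
        # (each query sees itself plus the previous `group_size_2 - 1` keys).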
        neighbor_attn_output, neighbor_softmax_lse_right_padded, neighbor_prob = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[group_size_2 - 1, 0],
            return_attn_probs=True,
        )

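        # Grouped attention: only the last `kv_seq_len - group_size_2` queries have keys outside the neighbor
        # window, and they attend (with grouped positions) to the first `kv_seq_len - group_size_2` keys/values.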
        group_attention_len = kv_seq_len - group_size_2

        group_attention_mask = attention_mask[:, :group_attention_len] if attention_mask is not None else None
        group_attn_output, group_softmax_lse_right_padded, group_prob = model_self._flash_attention_forward(
            group_query_states[:, -group_attention_len:, :, :],
            group_key_states[:, :group_attention_len, :, :],
            value_states[:, :group_attention_len, :, :],
            group_attention_mask,
            group_query_states[:, -group_attention_len:, :, :].shape[1],
            dropout=attn_dropout,
            window_size=[-1, -1],
            return_attn_probs=True,
        )

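        # Per-sequence number of valid (non-padding) positions for each of the two attention passes.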
        neighbor_seq_length = (
            torch.Tensor([kv_seq_len]).long().expand(bsz, 1)
            if attention_mask is None
            else torch.sum(attention_mask, axis=1, keepdim=True)
        )
        group_seq_length = (
            torch.Tensor([group_attention_len]).long().expand(bsz, 1)
            if attention_mask is None
            else torch.sum(attention_mask[:, :group_attention_len], axis=1, keepdim=True)
        )

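        # The flash-attention kernels return the softmax logsumexp right-padded (valid entries first). Shift
        # each sequence's valid entries to the end of the sequence dimension so they align with the
        # (left-padded) attention outputs below.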
        neighbor_softmax_lse = torch.zeros_like(neighbor_softmax_lse_right_padded)
        group_softmax_lse = torch.zeros_like(group_softmax_lse_right_padded)
        for idx in range(bsz):
            if neighbor_seq_length[idx] > 0:
                neighbor_softmax_lse[idx, :, -neighbor_seq_length[idx]:] = neighbor_softmax_lse_right_padded[
                    idx, :, :neighbor_seq_length[idx]
                ]
            if group_seq_length[idx] > 0:
                group_softmax_lse[idx, :, -group_seq_length[idx]:] = group_softmax_lse_right_padded[
                    idx, :, :group_seq_length[idx]
                ]

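        # Reshape the logsumexp tensors from (bsz, num_heads, seq_len) to (bsz, seq_len, num_heads, 1) so they
        # broadcast against the attention outputs of shape (bsz, seq_len, num_heads, head_dim).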
        true_neighbor_seq_max_length = neighbor_softmax_lse.shape[-1]
        true_group_seq_max_length = group_softmax_lse.shape[-1]

        neighbor_softmax_lse = neighbor_softmax_lse.transpose(1, 2).unsqueeze(-1)
        group_softmax_lse = group_softmax_lse.transpose(1, 2).unsqueeze(-1)

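        # Merge the two passes. With lse_gap = group_lse - neighbor_lse, the weights below are
        # sigmoid(-lse_gap) = exp(neighbor_lse) / (exp(neighbor_lse) + exp(group_lse)) for the neighbor output
        # and sigmoid(lse_gap) for the group output, i.e. the two partial softmaxes are renormalized into a
        # single softmax over their combined key sets. Positions that only have neighbor attention keep a
        # weight of 1.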
        lse_gap = group_softmax_lse - neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :]

        neighbor_softmax_lse[:, -true_group_seq_max_length:, :, :] = 1 / (1 + torch.exp(lse_gap))
        neighbor_softmax_lse[:, :-true_group_seq_max_length, :, :] = 1.0
        group_softmax_lse = 1 / (1 + torch.exp(-lse_gap))

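        # Scale each partial output by its merge weight, then accumulate them into a single output tensor.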
        neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] = (
            neighbor_attn_output[:, -true_neighbor_seq_max_length:, ...] * neighbor_softmax_lse
        )
        group_attn_output[:, -true_group_seq_max_length:, ...] = (
            group_attn_output[:, -true_group_seq_max_length:, ...] * group_softmax_lse
        )
        attn_output = torch.empty_like(neighbor_attn_output).copy_(neighbor_attn_output)

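        # `group_size_2 - kv_seq_len` equals `-group_attention_len`, so the grouped contribution is added to
        # the last `group_attention_len` query positions only. Any NaNs (e.g. from fully padded rows) are
        # replaced with zero.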
        attn_output[:, group_size_2 - kv_seq_len:, ...] += group_attn_output
        attn_output = torch.nan_to_num(attn_output, nan=0)

    else:
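        # Every query position lies within the neighbor window, so a single unwindowed flash-attention pass
        # over the neighbor states is sufficient.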
        attn_output = model_self._flash_attention_forward(
            neighbor_query_states,
            neighbor_key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=attn_dropout,
            window_size=[-1, -1],
        )

    return attn_output