| import torch |
| import logging |
| import torch.nn as nn |
| import transformers |
| from flash_attn.bert_padding import unpad_input, pad_input |
| from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func |
| from einops import rearrange |
| from typing import List, Optional, Tuple, Union |
|
|
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """Flash-attention replacement for ``EsmSelfAttention.forward``.

    Only plain encoder self-attention with rotary position embeddings is
    handled: padding is stripped with ``unpad_input``, attention is computed
    with ``flash_attn_varlen_qkvpacked_func`` and the result is re-padded with
    ``pad_input``.

    Args:
        hidden_states: Input of shape ``(batch, seq_len, hidden)``.
        attention_mask: Required. Forwarded unchanged to ``unpad_input`` as the
            key padding mask (presumably ``(batch, seq_len)`` with non-zero
            entries marking real tokens -- confirm against the installed
            flash-attn version).
        head_mask: Must be ``None``.
        encoder_hidden_states: Must be ``None`` (no cross-attention).
        encoder_attention_mask: Ignored.
        past_key_value: Must be ``None`` (no key/value cache).
        output_attentions: Must be ``False``; attention probabilities are
            never returned by this implementation.

    Returns:
        A 1-tuple holding the re-padded context tensor of shape
        ``(batch, seq_len, heads * head_dim)``, in the dtype of the value
        projection.

    Raises:
        AssertionError: If an unsupported feature is requested or
            ``attention_mask`` is missing.
    """
    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # AssertionError and the messages are kept for backward compatibility.
    # Checking first also avoids running the projections on a doomed call.
    unsupported = (
        (encoder_hidden_states is not None, "Cross-attention is not supported for ESM"),
        (past_key_value is not None, "Past key value is not supported for ESM"),
        (self.is_decoder is not False, "Decoder is not supported for ESM"),
        (self.position_embedding_type != "rotary", "Rotary embeddings are required for ESM"),
        (head_mask is not None, "Head mask is not supported for ESM"),
        (output_attentions is not False, "Output attentions is not supported for ESM"),
        (attention_mask is None, "attention_mask is required by the flash-attention forward"),
    )
    for violated, message in unsupported:
        if violated:
            raise AssertionError(message)

    query_layer = self.transpose_for_scores(self.query(hidden_states))
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))

    # The 1/sqrt(head_dim) scaling is applied to the query up front, which is
    # why the kernel below is called with softmax_scale=1.
    query_layer = query_layer * self.attention_head_size**-0.5
    query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

    # Stack to (b, h, 3, s, d), then swap the head and sequence axes to get the
    # (b, s, 3, h, d) layout that the rearrange pattern below expects.
    qkv = torch.stack([query_layer, key_layer, value_layer], dim=2).transpose(1, 3)
    bsz, q_len, _ = hidden_states.size()
    nheads = qkv.shape[-2]
    # Remember the dtype so the bfloat16 cast below does not leak to callers.
    output_dtype = value_layer.dtype

    packed = rearrange(qkv, "b s three h d -> b s (three h d)")
    # Starred unpacking: newer flash-attn releases return an extra fifth value.
    x_unpad, indices, cu_q_lens, max_s, *_ = unpad_input(packed, attention_mask)
    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
    x_unpad = x_unpad.to(torch.bfloat16)

    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_q_lens,
        max_s,
        self.dropout.p if self.training else 0.0,  # dropout only while training
        softmax_scale=1,
        causal=False,
    )

    # Scatter the per-token results back into the padded (b, s, h*d) layout.
    outputs = pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len)
    return (outputs.to(output_dtype),)
|
|
|
|
def get_extended_attention_mask(
    self,
    attention_mask: torch.Tensor,
    input_shape: Tuple[int],
    device: torch.device = None,
    dtype: torch.float = None,
) -> torch.Tensor:
    """Return ``attention_mask`` exactly as given.

    Identity stand-in for the mask-expansion hook: no broadcasting, dtype
    conversion or additive-bias transformation is applied, so whatever mask the
    caller supplied is handed on unchanged.

    Args:
        attention_mask: Mask to pass through.
        input_shape: Unused; kept for signature compatibility.
        device: Unused; kept for signature compatibility.
        dtype: Unused; kept for signature compatibility.

    Returns:
        The very same ``attention_mask`` object.
    """
    return attention_mask
|
|
|
|
def forward_original(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
    """Eager (non-flash) attention forward pass, used to undo the flash patch.

    Args:
        hidden_states: Input to the query projection (and to the key/value
            projections for self-attention).
        attention_mask: Optional tensor added to the raw attention scores.
        head_mask: Optional tensor multiplied into the attention probabilities.
        encoder_hidden_states: When given, keys/values are derived from it
            (cross-attention) and ``encoder_attention_mask`` replaces
            ``attention_mask``.
        encoder_attention_mask: Mask used in the cross-attention case.
        past_key_value: Optional cached ``(key, value)`` pair. It is reused
            as-is for cross-attention, or prepended along dim 2 for
            self-attention.
        output_attentions: Whether to also return the attention probabilities.

    Returns:
        ``(context,)`` or ``(context, attention_probs)``; when
        ``self.is_decoder`` is truthy the ``(key, value)`` cache is appended.
    """
    projected_query = self.query(hidden_states)

    cross_attention = encoder_hidden_states is not None
    has_cache = past_key_value is not None

    if cross_attention:
        # Cross-attention always masks with the encoder-side mask.
        attention_mask = encoder_attention_mask
        if has_cache:
            # Cached encoder keys/values are reused without recomputation.
            key_layer, value_layer = past_key_value[0], past_key_value[1]
        else:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
    else:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        if has_cache:
            # Prepend the cached keys/values to the freshly projected ones.
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

    # The 1/sqrt(head_dim) factor is folded into the query before the dot product.
    query_layer = self.transpose_for_scores(projected_query) * self.attention_head_size**-0.5

    if self.is_decoder:
        # Captured before the rotary step, so the cache holds un-rotated keys.
        past_key_value = (key_layer, value_layer)

    if self.position_embedding_type == "rotary":
        query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type in ("relative_key", "relative_key_query"):
        positions = torch.arange(hidden_states.size(1), dtype=torch.long, device=hidden_states.device)
        distance = positions.view(-1, 1) - positions.view(1, -1)
        # Shift the distances so they are non-negative embedding indices.
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

        query_bias = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        if self.position_embedding_type == "relative_key":
            attention_scores = attention_scores + query_bias
        else:
            key_bias = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + query_bias + key_bias

    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    # Normalise the scores into probabilities, then apply dropout to them.
    attention_probs = self.dropout(torch.softmax(attention_scores, dim=-1))

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # Move the head axis next to the feature axis and merge the two.
    context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()
    context_layer = context_layer.view(*context_layer.shape[:-2], self.all_head_size)

    outputs = (context_layer,)
    if output_attentions:
        outputs = outputs + (attention_probs,)
    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs
|
|
|
|
def get_extended_attention_mask_original(
    self, attention_mask: torch.Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
) -> torch.Tensor:
    """
    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

    Arguments:
        attention_mask (`torch.Tensor`):
            Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Must be 2-D or 3-D.
        input_shape (`Tuple[int]`):
            The shape of the input to the model.
        device (`torch.device`, *optional*):
            Deprecated. Only forwarded to the decoder mask helper; a warning is logged when it is passed in any
            other case.
        dtype (`torch.dtype`, *optional*):
            dtype of the returned mask. Defaults to `self.dtype`.

    Returns:
        `torch.Tensor` The extended attention mask in `dtype`: `0.0` where attention is allowed and the most
        negative finite value of `dtype` where it is not.

    Raises:
        ValueError: If `attention_mask` is neither 2-D nor 3-D.
    """
    if dtype is None:
        dtype = self.dtype

    if not (attention_mask.dim() == 2 and self.config.is_decoder):
        # In the 2-D decoder case `device` is handed to the helper below instead.
        if device is not None:
            # Logged rather than printed so library users can filter or redirect it.
            logging.warning(
                "The `device` argument is deprecated and will be removed in v5 of Transformers."
            )

    if attention_mask.dim() == 3:
        # A [batch, from_seq, to_seq] mask only needs a broadcastable head axis.
        extended_attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        if self.config.is_decoder:
            # Imported here because the module never brings ModuleUtilsMixin into
            # scope; the bare name used previously raised NameError on this path.
            from transformers.modeling_utils import ModuleUtilsMixin

            extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                input_shape, attention_mask, device
            )
        else:
            # A [batch, seq] padding mask is broadcast over heads and query positions.
            extended_attention_mask = attention_mask[:, None, None, :]
    else:
        raise ValueError(
            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
        )

    # Turn the 1/0 mask into an additive bias: 0.0 for kept positions and the
    # dtype's minimum for masked ones, so it can be added to the raw scores
    # before the softmax.
    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
    extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
    return extended_attention_mask
|
|
|
|
def replace_esm_attn_with_flash_attn():
    """Monkey-patch Hugging Face ESM to use the flash-attention forward pass.

    Replaces ``EsmSelfAttention.forward`` with this module's ``forward`` and
    ``EsmModel.get_extended_attention_mask`` with the pass-through version, so
    the attention layer receives the mask unmodified. The assignments are made
    on the classes themselves, so every instance is affected.

    A warning is logged when the current CUDA device reports a compute
    capability major version below 8.
    """
    # Only the major version matters for the capability check.
    cuda_major, _ = torch.cuda.get_device_capability()
    if cuda_major < 8:
        # The trailing space keeps the two adjacent literals from fusing into
        # "backward.ref:" in the emitted message.
        logging.warning(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

    transformers.models.esm.modeling_esm.EsmModel.get_extended_attention_mask = get_extended_attention_mask
    transformers.models.esm.modeling_esm.EsmSelfAttention.forward = forward
|
|
|
|
def replace_flash_attn_with_esm_attn():
    """Undo the flash-attention patch and reinstate the eager ESM attention.

    Points ``EsmSelfAttention.forward`` at ``forward_original`` and
    ``EsmModel.get_extended_attention_mask`` at
    ``get_extended_attention_mask_original``. The assignments are made on the
    classes themselves, so every instance is affected.

    No GPU capability probe is done here: the eager path does not use flash
    attention, and querying CUDA would make the restore fail on machines
    without a usable CUDA device.
    """
    transformers.models.esm.modeling_esm.EsmModel.get_extended_attention_mask = get_extended_attention_mask_original
    transformers.models.esm.modeling_esm.EsmSelfAttention.forward = forward_original
|
|
# Running this module as a script is intentionally a no-op; it is meant to be
# imported for its patch / un-patch helpers.
if __name__ == '__main__':
    pass