lhallee committed on
Commit f175f2d · verified · 1 Parent(s): 7a27746

Upload modeling_fastesm.py with huggingface_hub

Files changed (1)
  1. modeling_fastesm.py +227 -37
modeling_fastesm.py CHANGED
@@ -1,15 +1,14 @@
+import entrypoint_setup
 import torch
 import torch.nn as nn
-import os
 import warnings
-import networkx as nx
 from torch.nn import functional as F
-from torch.utils.data import Dataset as TorchDataset
-from torch.utils.data import DataLoader as DataLoader
-from typing import Optional, Tuple, Union, Callable, List, Dict, Any
+from torch.nn.attention.flex_attention import create_block_mask
+from torch.nn.attention.flex_attention import flex_attention
+from typing import Optional, Tuple, Union, Dict, Any
 from einops import rearrange
 from dataclasses import dataclass
-from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
 from transformers.modeling_outputs import (
     ModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -25,29 +24,11 @@ from transformers.models.esm.modeling_esm import (
     EsmSelfOutput,
     EsmClassificationHead,
 )
-from tqdm.auto import tqdm
-from embedding_mixin import EmbeddingMixin, Pooler
 
-try:
-    from torch.nn.attention.flex_attention import create_block_mask
-    from torch.nn.attention.flex_attention import flex_attention as _raw_flex_attention
-except ImportError:
-    create_block_mask = None
-    _raw_flex_attention = None
+from .embedding_mixin import EmbeddingMixin
 
 
-def _resolve_flex_attention(attn_compile: bool):
-    if _raw_flex_attention is None:
-        return None
-    if not attn_compile:
-        return _raw_flex_attention
-    try:
-        return torch.compile(_raw_flex_attention, dynamic=True)
-    except Exception:
-        return _raw_flex_attention
-
-
-def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
+def _create_pad_block_mask(attention_mask_2d: torch.Tensor):
     assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
     token_valid = attention_mask_2d.bool()
     batch_size, seq_len = token_valid.shape
@@ -62,7 +43,6 @@ def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
         seq_len,
         seq_len,
         device=attention_mask_2d.device,
-        BLOCK_SIZE=block_size,
     )
 
 
@@ -94,9 +74,7 @@ class FastEsmConfig(PretrainedConfig):
         position_embedding_type: str = "absolute",
         emb_layer_norm_before: bool = None,
         token_dropout: bool = True,
-        attn_backend: str = "flex",
-        attn_compile: bool = True,
-        flex_block_size: int = 128,
+        attn_backend: str = "sdpa",
         **kwargs,
     ):
         super().__init__(
@@ -120,8 +98,6 @@
         self.tie_word_embeddings = False
        self.token_dropout = token_dropout
         self.attn_backend = attn_backend
-        self.attn_compile = attn_compile
-        self.flex_block_size = flex_block_size
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -329,8 +305,6 @@ class EsmSelfAttention(nn.Module):
 
         self.dropout_prob = config.attention_probs_dropout_prob
         self.attn_backend = config.attn_backend
-        self.flex_block_size = config.flex_block_size
-        self.flex_attention = _resolve_flex_attention(config.attn_compile)
         self._warned_flex_fallback = False
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
@@ -384,18 +358,16 @@
             sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
         use_flex = (
             self.attn_backend == "flex"
-            and self.flex_attention is not None
             and (attention_mask is None or flex_block_mask is not None)
         )
         if use_flex:
             try:
-                context_layer = self.flex_attention(
+                context_layer = flex_attention(
                     query_layer,
                     key_layer,
                     value_layer,
                     block_mask=flex_block_mask,
                     scale=1.0,
-                    enable_gqa=query_layer.shape[1] != key_layer.shape[1],
                 )
             except Exception as exc:
                 if not self._warned_flex_fallback:
@@ -586,6 +558,224 @@ class EsmEncoder(nn.Module):
         )
 
 
+class FastEsmPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+    config_class = FastEsmConfig
+    base_model_prefix = "fastesm"
+    supports_gradient_checkpointing = True
+    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
+    all_tied_weights_keys = {}
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            if module.bias is not None:
+                module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def get_input_embeddings(self) -> nn.Module:
+        try:
+            return self.embeddings.word_embeddings
+        except AttributeError:
+            return self.esm.embeddings.word_embeddings
+
+
+class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
+    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
+        self.config = config
+        self.embeddings = EsmEmbeddings(config)
+        self.encoder = EsmEncoder(config)
+        self.contact_head = EsmContactPredictionHead(
+            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
+        )
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
+        batch_size, seq_length = input_ids.shape
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+        encoder_outputs = self.encoder(
+            token_embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=False,
+            output_attentions=False,
+        )
+        return encoder_outputs.last_hidden_state
+
+    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
+        attns = torch.stack(attns, dim=1)
+        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
+        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
+        return self.contact_head(input_ids, attns)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        token_embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if attention_mask is not None:
+            extended_attention_mask = attention_mask[:, None, None, :].expand(
+                batch_size, 1, seq_length, seq_length
+            ).bool()
+        else:
+            extended_attention_mask = None
+
+        encoder_outputs = self.encoder(
+            token_embedding_output,
+            attention_mask=extended_attention_mask,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = encoder_outputs.last_hidden_state
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
+    def __init__(self, config, add_pooling_layer: Optional[bool] = True, **kwargs):
+        FastEsmPreTrainedModel.__init__(self, config, **kwargs)
+        self.config = config
+        self.esm = FAST_ESM_ENCODER(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.esm._embed(input_ids, attention_mask)
+
+    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,  # to play nice with HF adjacent packages
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """Forward pass for base model.
+
+        Args:
+            input_ids: Input token IDs
+            attention_mask: Optional attention mask
+            position_ids: Optional position IDs
+            inputs_embeds: Optional input embeddings
+            output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
+
+        Returns:
+            Model outputs including hidden states and optionally attention weights
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        outputs = self.esm(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+        )
+        sequence_output = outputs.last_hidden_state
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
     def __init__(self, config, **kwargs):
         FastEsmPreTrainedModel.__init__(self, config, **kwargs)
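
The FastEsmModel and FAST_ESM_ENCODER classes added in this commit are ordinary Hugging Face PreTrainedModel subclasses, so they load through the usual trust_remote_code path, and the attn_backend option introduced in FastEsmConfig (default "sdpa", with "flex" routing through flex_attention) can be overridden at load time. A minimal usage sketch, not part of the commit; the repository id below is a placeholder for whichever repo ships this modeling_fastesm.py in its auto_map:

import torch
from transformers import AutoModel

repo_id = "<org>/<fastesm-repo>"  # hypothetical placeholder, replace with the actual Hub repo

model = AutoModel.from_pretrained(
    repo_id,
    trust_remote_code=True,   # resolves to FastEsmModel via the repo's auto_map
    attn_backend="sdpa",      # config field from this commit; use "flex" for flex_attention
)
tokenizer = model.tokenizer   # EsmTokenizer attached on FastEsmPreTrainedModel

inputs = tokenizer("MEEPQSDPSV", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)
print(out.last_hidden_state.shape)  # (batch, seq_len, hidden_size)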