lhallee committed on
Commit
e250ef1
·
verified ·
1 Parent(s): 66589ea

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +26 -2
modeling_esm_plusplus.py CHANGED
@@ -886,7 +886,7 @@ class TransformerStack(nn.Module):
886
  attn_backend: str = "sdpa",
887
  ):
888
  super().__init__()
889
- self.attn_backend = attn_backend
890
  self.blocks = nn.ModuleList(
891
  [
892
  UnifiedTransformerBlock(
@@ -901,6 +901,18 @@ class TransformerStack(nn.Module):
901
  )
902
  self.norm = nn.LayerNorm(d_model, bias=False)
903
  self.gradient_checkpointing = False
 
 
 
 
 
 
 
 
 
 
 
 
904
 
905
  def forward(
906
  self,
@@ -924,7 +936,7 @@ class TransformerStack(nn.Module):
924
 
925
  # move to 4D attention mask or flex block mask
926
  attention_mask, flex_block_mask = get_attention_mask(
927
- attn_backend=self.attn_backend,
928
  batch_size=x.shape[0],
929
  seq_len=x.shape[1],
930
  device=x.device,
@@ -997,6 +1009,18 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
997
  nn.init.zeros_(module.bias)
998
  nn.init.ones_(module.weight)
999
 
 
 
 
 
 
 
 
 
 
 
 
 
1000
  def _reset_rotary_embeddings(self):
1001
  """Refresh non-persistent rotary buffers after checkpoint loading."""
1002
  for module in self.modules():
 
886
  attn_backend: str = "sdpa",
887
  ):
888
  super().__init__()
889
+ self._attn_backend = attn_backend
890
  self.blocks = nn.ModuleList(
891
  [
892
  UnifiedTransformerBlock(
 
901
  )
902
  self.norm = nn.LayerNorm(d_model, bias=False)
903
  self.gradient_checkpointing = False
904
+ self.attn_backend = attn_backend
905
+
906
+ @property
907
+ def attn_backend(self) -> str:
908
+ return self._attn_backend
909
+
910
+ @attn_backend.setter
911
+ def attn_backend(self, backend: str) -> None:
912
+ assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
913
+ self._attn_backend = backend
914
+ for block in self.blocks:
915
+ block.attn.attn_backend = backend
916
 
917
  def forward(
918
  self,
 
936
 
937
  # move to 4D attention mask or flex block mask
938
  attention_mask, flex_block_mask = get_attention_mask(
939
+ attn_backend=self._attn_backend,
940
  batch_size=x.shape[0],
941
  seq_len=x.shape[1],
942
  device=x.device,
 
1009
  nn.init.zeros_(module.bias)
1010
  nn.init.ones_(module.weight)
1011
 
1012
+ @property
1013
+ def attn_backend(self) -> str:
1014
+ return self.config.attn_backend
1015
+
1016
+ @attn_backend.setter
1017
+ def attn_backend(self, backend: str) -> None:
1018
+ assert backend in ("sdpa", "flex"), f"Unsupported attn_backend: {backend}"
1019
+ self.config.attn_backend = backend
1020
+ for module in self.modules():
1021
+ if isinstance(module, TransformerStack):
1022
+ module.attn_backend = backend
1023
+
1024
  def _reset_rotary_embeddings(self):
1025
  """Refresh non-persistent rotary buffers after checkpoint loading."""
1026
  for module in self.modules():