Transformers · Safetensors · dplm2 · custom_code
lhallee committed · Commit 584971e · verified · 1 Parent(s): ee41aa8

Upload modeling_dplm2.py with huggingface_hub

Files changed (1):
  modeling_dplm2.py  +18 -16
modeling_dplm2.py CHANGED
@@ -419,9 +419,9 @@ def get_attention_mask(
     attention_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[Optional[torch.Tensor], Optional[object]]:
     if attention_mask is None:
-        token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
+        attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
     else:
-        token_attention_mask = attention_mask.bool()
+        attention_mask_2d = attention_mask.bool()
 
     if attn_backend == "flex":
         assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
@@ -429,8 +429,10 @@ def get_attention_mask(
         if attention_mask is None:
             flex_block_mask = None
         else:
+            valid_lens = attention_mask_2d.sum(dim=-1)
+
             def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-                return (token_attention_mask[batch_idx, q_idx] == token_attention_mask[batch_idx, kv_idx]) & (token_attention_mask[batch_idx, q_idx] != 0)
+                return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
 
             flex_block_mask = create_block_mask(
                 mask_mod,
@@ -440,12 +442,12 @@ def get_attention_mask(
                 seq_len,
                 device=device,
             )
-        extended_attention_mask = None
+        attention_mask_4d = None
     else:
         flex_block_mask = None
-        extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
+        attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
 
-    return extended_attention_mask, flex_block_mask
+    return attention_mask_4d, flex_block_mask
 
 
 def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
@@ -989,12 +991,12 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if attention_mask is None:
-            token_attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
+            attention_mask_2d = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
         elif attention_mask.dim() == 2:
-            token_attention_mask = attention_mask.bool()
+            attention_mask_2d = attention_mask.bool()
         elif attention_mask.dim() == 4:
             assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
-            token_attention_mask = input_ids.ne(self.config.pad_token_id)
+            attention_mask_2d = input_ids.ne(self.config.pad_token_id)
         else:
             raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
 
@@ -1009,19 +1011,19 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
 
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        embedding_attention_mask = token_attention_mask
+        embedding_attention_mask = attention_mask_2d
         if embedding_attention_mask is None and input_ids is not None:
             embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
 
         if self.config.attn_backend == "flex" and output_attentions:
             raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
 
-        extended_attention_mask, flex_block_mask = get_attention_mask(
+        attention_mask_4d, flex_block_mask = get_attention_mask(
             attn_backend=self.config.attn_backend,
             batch_size=batch_size,
             seq_len=seq_length,
             device=device,
-            attention_mask=token_attention_mask,
+            attention_mask=attention_mask_2d,
         )
 
         embedding_output = self.embeddings(
@@ -1032,7 +1034,7 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
         )
         encoder_outputs = self.encoder(
             embedding_output,
-            attention_mask=extended_attention_mask,
+            attention_mask=attention_mask_4d,
             head_mask=head_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_extended_attention_mask,
@@ -1134,7 +1136,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
         if vocab_size is not None:
             config.vocab_size = vocab_size
         DPLM2PreTrainedModel.__init__(self, config)
-        self.esm = DPLM2Model(config, add_pooling_layer=False)
+        self.esm = FAST_DPLM2_ENCODER(config)
         self.lm_head = EsmLMHead(config)
         self.loss_fct = nn.CrossEntropyLoss()
         self.post_init()
@@ -1235,7 +1237,7 @@ class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
     def __init__(self, config):
         DPLM2PreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels
-        self.esm = DPLM2Model(config, add_pooling_layer=False)
+        self.esm = FAST_DPLM2_ENCODER(config)
         self.classifier = EsmClassificationHead(config)
         self.mse = nn.MSELoss()
         self.ce = nn.CrossEntropyLoss()
@@ -1312,7 +1314,7 @@ class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
    def __init__(self, config):
         DPLM2PreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels
-        self.esm = DPLM2Model(config, add_pooling_layer=False)
+        self.esm = FAST_DPLM2_ENCODER(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
         self.loss_fct = nn.CrossEntropyLoss()
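
For context on what the new mask construction computes, here is a minimal standalone sketch (not part of the commit; it assumes PyTorch 2.5+ so that torch.nn.attention.flex_attention.create_block_mask is importable, and uses made-up batch sizes and sequence lengths). It builds the same right-padding mask both ways: as the dense 4D boolean mask returned by the non-flex branch of get_attention_mask, and as a FlexAttention block mask driven by per-sequence valid lengths, mirroring the updated mask_mod.

# Illustrative sketch only; assumes PyTorch >= 2.5 (flex_attention API).
import torch
from torch.nn.attention.flex_attention import create_block_mask

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size, seq_len = 2, 128

# Hypothetical right-padded batch: per-sequence number of real tokens.
lengths = torch.tensor([100, 128], device=device)
attention_mask_2d = torch.arange(seq_len, device=device)[None, :] < lengths[:, None]

# Dense path (eager/SDPA): broadcast the 2D mask into a (B, 1, Q, KV) boolean mask.
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]

# Flex path: a query may attend to a key only if both indices fall inside the
# sequence's valid length (this assumes right-padded batches).
valid_lens = attention_mask_2d.sum(dim=-1)

def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])

flex_block_mask = create_block_mask(
    mask_mod, batch_size, None, seq_len, seq_len, device=device
)

# Sanity check: the dense 4D mask and the length-based predicate agree
# for right-padded input.
q = torch.arange(seq_len, device=device)
dense_from_mod = (q[None, :, None] < valid_lens[:, None, None]) & (q[None, None, :] < valid_lens[:, None, None])
assert torch.equal(attention_mask_4d.squeeze(1), dense_from_mod)
print(flex_block_mask)  # block-sparse summary of the same padding mask

The check at the end only holds under the right-padding assumption; sequences with non-contiguous valid tokens would need a mask_mod that indexes the 2D mask directly rather than comparing against a single valid length.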