Upload modeling_dplm2.py with huggingface_hub
Browse files — modeling_dplm2.py (+18 −16)
modeling_dplm2.py
CHANGED
|
@@ -419,9 +419,9 @@ def get_attention_mask(
|
|
| 419 |
attention_mask: Optional[torch.Tensor] = None,
|
| 420 |
) -> Tuple[Optional[torch.Tensor], Optional[object]]:
|
| 421 |
if attention_mask is None:
|
| 422 |
-
|
| 423 |
else:
|
| 424 |
-
|
| 425 |
|
| 426 |
if attn_backend == "flex":
|
| 427 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
@@ -429,8 +429,10 @@ def get_attention_mask(
|
|
| 429 |
if attention_mask is None:
|
| 430 |
flex_block_mask = None
|
| 431 |
else:
|
|
|
|
|
|
|
| 432 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 433 |
-
return (
|
| 434 |
|
| 435 |
flex_block_mask = create_block_mask(
|
| 436 |
mask_mod,
|
|
@@ -440,12 +442,12 @@ def get_attention_mask(
|
|
| 440 |
seq_len,
|
| 441 |
device=device,
|
| 442 |
)
|
| 443 |
-
|
| 444 |
else:
|
| 445 |
flex_block_mask = None
|
| 446 |
-
|
| 447 |
|
| 448 |
-
return
|
| 449 |
|
| 450 |
|
| 451 |
def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
@@ -989,12 +991,12 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 989 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 990 |
|
| 991 |
if attention_mask is None:
|
| 992 |
-
|
| 993 |
elif attention_mask.dim() == 2:
|
| 994 |
-
|
| 995 |
elif attention_mask.dim() == 4:
|
| 996 |
assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
|
| 997 |
-
|
| 998 |
else:
|
| 999 |
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
| 1000 |
|
|
@@ -1009,19 +1011,19 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1009 |
|
| 1010 |
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 1011 |
|
| 1012 |
-
embedding_attention_mask =
|
| 1013 |
if embedding_attention_mask is None and input_ids is not None:
|
| 1014 |
embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 1015 |
|
| 1016 |
if self.config.attn_backend == "flex" and output_attentions:
|
| 1017 |
raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
|
| 1018 |
|
| 1019 |
-
|
| 1020 |
attn_backend=self.config.attn_backend,
|
| 1021 |
batch_size=batch_size,
|
| 1022 |
seq_len=seq_length,
|
| 1023 |
device=device,
|
| 1024 |
-
attention_mask=
|
| 1025 |
)
|
| 1026 |
|
| 1027 |
embedding_output = self.embeddings(
|
|
@@ -1032,7 +1034,7 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1032 |
)
|
| 1033 |
encoder_outputs = self.encoder(
|
| 1034 |
embedding_output,
|
| 1035 |
-
attention_mask=
|
| 1036 |
head_mask=head_mask,
|
| 1037 |
encoder_hidden_states=encoder_hidden_states,
|
| 1038 |
encoder_attention_mask=encoder_extended_attention_mask,
|
|
@@ -1134,7 +1136,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1134 |
if vocab_size is not None:
|
| 1135 |
config.vocab_size = vocab_size
|
| 1136 |
DPLM2PreTrainedModel.__init__(self, config)
|
| 1137 |
-
self.esm =
|
| 1138 |
self.lm_head = EsmLMHead(config)
|
| 1139 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 1140 |
self.post_init()
|
|
@@ -1235,7 +1237,7 @@ class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1235 |
def __init__(self, config):
|
| 1236 |
DPLM2PreTrainedModel.__init__(self, config)
|
| 1237 |
self.num_labels = config.num_labels
|
| 1238 |
-
self.esm =
|
| 1239 |
self.classifier = EsmClassificationHead(config)
|
| 1240 |
self.mse = nn.MSELoss()
|
| 1241 |
self.ce = nn.CrossEntropyLoss()
|
|
@@ -1312,7 +1314,7 @@ class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1312 |
def __init__(self, config):
|
| 1313 |
DPLM2PreTrainedModel.__init__(self, config)
|
| 1314 |
self.num_labels = config.num_labels
|
| 1315 |
-
self.esm =
|
| 1316 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1317 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1318 |
self.loss_fct = nn.CrossEntropyLoss()
|
|
|
|
| 419 |
attention_mask: Optional[torch.Tensor] = None,
|
| 420 |
) -> Tuple[Optional[torch.Tensor], Optional[object]]:
|
| 421 |
if attention_mask is None:
|
| 422 |
+
attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
|
| 423 |
else:
|
| 424 |
+
attention_mask_2d = attention_mask.bool()
|
| 425 |
|
| 426 |
if attn_backend == "flex":
|
| 427 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
|
|
| 429 |
if attention_mask is None:
|
| 430 |
flex_block_mask = None
|
| 431 |
else:
|
| 432 |
+
valid_lens = attention_mask_2d.sum(dim=-1)
|
| 433 |
+
|
| 434 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 435 |
+
return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
|
| 436 |
|
| 437 |
flex_block_mask = create_block_mask(
|
| 438 |
mask_mod,
|
|
|
|
| 442 |
seq_len,
|
| 443 |
device=device,
|
| 444 |
)
|
| 445 |
+
attention_mask_4d = None
|
| 446 |
else:
|
| 447 |
flex_block_mask = None
|
| 448 |
+
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
|
| 449 |
|
| 450 |
+
return attention_mask_4d, flex_block_mask
|
| 451 |
|
| 452 |
|
| 453 |
def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 991 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 992 |
|
| 993 |
if attention_mask is None:
|
| 994 |
+
attention_mask_2d = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
|
| 995 |
elif attention_mask.dim() == 2:
|
| 996 |
+
attention_mask_2d = attention_mask.bool()
|
| 997 |
elif attention_mask.dim() == 4:
|
| 998 |
assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
|
| 999 |
+
attention_mask_2d = input_ids.ne(self.config.pad_token_id)
|
| 1000 |
else:
|
| 1001 |
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
| 1002 |
|
|
|
|
| 1011 |
|
| 1012 |
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 1013 |
|
| 1014 |
+
embedding_attention_mask = attention_mask_2d
|
| 1015 |
if embedding_attention_mask is None and input_ids is not None:
|
| 1016 |
embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 1017 |
|
| 1018 |
if self.config.attn_backend == "flex" and output_attentions:
|
| 1019 |
raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
|
| 1020 |
|
| 1021 |
+
attention_mask_4d, flex_block_mask = get_attention_mask(
|
| 1022 |
attn_backend=self.config.attn_backend,
|
| 1023 |
batch_size=batch_size,
|
| 1024 |
seq_len=seq_length,
|
| 1025 |
device=device,
|
| 1026 |
+
attention_mask=attention_mask_2d,
|
| 1027 |
)
|
| 1028 |
|
| 1029 |
embedding_output = self.embeddings(
|
|
|
|
| 1034 |
)
|
| 1035 |
encoder_outputs = self.encoder(
|
| 1036 |
embedding_output,
|
| 1037 |
+
attention_mask=attention_mask_4d,
|
| 1038 |
head_mask=head_mask,
|
| 1039 |
encoder_hidden_states=encoder_hidden_states,
|
| 1040 |
encoder_attention_mask=encoder_extended_attention_mask,
|
|
|
|
| 1136 |
if vocab_size is not None:
|
| 1137 |
config.vocab_size = vocab_size
|
| 1138 |
DPLM2PreTrainedModel.__init__(self, config)
|
| 1139 |
+
self.esm = FAST_DPLM2_ENCODER(config)
|
| 1140 |
self.lm_head = EsmLMHead(config)
|
| 1141 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 1142 |
self.post_init()
|
|
|
|
| 1237 |
def __init__(self, config):
|
| 1238 |
DPLM2PreTrainedModel.__init__(self, config)
|
| 1239 |
self.num_labels = config.num_labels
|
| 1240 |
+
self.esm = FAST_DPLM2_ENCODER(config)
|
| 1241 |
self.classifier = EsmClassificationHead(config)
|
| 1242 |
self.mse = nn.MSELoss()
|
| 1243 |
self.ce = nn.CrossEntropyLoss()
|
|
|
|
| 1314 |
def __init__(self, config):
|
| 1315 |
DPLM2PreTrainedModel.__init__(self, config)
|
| 1316 |
self.num_labels = config.num_labels
|
| 1317 |
+
self.esm = FAST_DPLM2_ENCODER(config)
|
| 1318 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1319 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1320 |
self.loss_fct = nn.CrossEntropyLoss()
|