Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +119 -79
modeling_esm_plusplus.py
CHANGED
|
@@ -391,23 +391,38 @@ except ImportError:
|
|
| 391 |
flex_attention = None
|
| 392 |
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
|
| 395 |
-
|
| 396 |
-
assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
|
| 397 |
-
token_valid = attention_mask_2d.bool()
|
| 398 |
-
batch_size, seq_len = token_valid.shape
|
| 399 |
-
|
| 400 |
-
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 401 |
-
return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
|
| 402 |
-
|
| 403 |
-
return create_block_mask(
|
| 404 |
-
mask_mod,
|
| 405 |
-
batch_size,
|
| 406 |
-
1,
|
| 407 |
-
seq_len,
|
| 408 |
-
seq_len,
|
| 409 |
-
device=attention_mask_2d.device,
|
| 410 |
-
)
|
| 411 |
|
| 412 |
|
| 413 |
class ESMplusplusConfig(PretrainedConfig):
|
|
@@ -702,14 +717,15 @@ class MultiHeadAttention(nn.Module):
|
|
| 702 |
def forward(
|
| 703 |
self,
|
| 704 |
x: torch.Tensor,
|
| 705 |
-
attention_mask:
|
| 706 |
-
flex_block_mask:
|
| 707 |
output_attentions: bool = False,
|
| 708 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 709 |
"""
|
| 710 |
Args:
|
| 711 |
x: Input tensor
|
| 712 |
-
attention_mask:
|
|
|
|
| 713 |
output_attentions: Whether to return attention weights
|
| 714 |
|
| 715 |
Returns:
|
|
@@ -727,24 +743,15 @@ class MultiHeadAttention(nn.Module):
|
|
| 727 |
scale = 1 / math.sqrt(self.d_head)
|
| 728 |
|
| 729 |
if output_attentions: # Manual attention computation
|
| 730 |
-
b, h, l, _ = query_BHLD.shape
|
| 731 |
-
attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
|
| 732 |
-
if attention_mask is not None:
|
| 733 |
-
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
| 734 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
| 735 |
-
attn_weights
|
| 736 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 737 |
context_BHLD = torch.matmul(attn_weights, value_BHLD)
|
| 738 |
else:
|
| 739 |
if self.attn_backend == "flex":
|
| 740 |
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 741 |
-
assert query_BHLD.dtype in (torch.float16, torch.bfloat16),
|
| 742 |
-
|
| 743 |
-
)
|
| 744 |
-
if attention_mask is not None:
|
| 745 |
-
assert flex_block_mask is not None, (
|
| 746 |
-
"Flex attention backend requires a block mask when attention_mask is provided."
|
| 747 |
-
)
|
| 748 |
context_BHLD = flex_attention(
|
| 749 |
query_BHLD,
|
| 750 |
key_BHLD,
|
|
@@ -753,15 +760,11 @@ class MultiHeadAttention(nn.Module):
|
|
| 753 |
scale=scale,
|
| 754 |
)
|
| 755 |
else:
|
| 756 |
-
sdpa_mask = None
|
| 757 |
-
if attention_mask is not None:
|
| 758 |
-
sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
|
| 759 |
-
sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
| 760 |
context_BHLD = F.scaled_dot_product_attention(
|
| 761 |
query_BHLD,
|
| 762 |
key_BHLD,
|
| 763 |
value_BHLD,
|
| 764 |
-
attn_mask=
|
| 765 |
scale=scale,
|
| 766 |
)
|
| 767 |
|
|
@@ -820,14 +823,15 @@ class UnifiedTransformerBlock(nn.Module):
|
|
| 820 |
def forward(
|
| 821 |
self,
|
| 822 |
x: torch.Tensor,
|
| 823 |
-
attention_mask:
|
| 824 |
-
flex_block_mask:
|
| 825 |
output_attentions: bool = False,
|
| 826 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 827 |
"""
|
| 828 |
Args:
|
| 829 |
x: Input tensor
|
| 830 |
-
attention_mask:
|
|
|
|
| 831 |
output_attentions: Whether to return attention weights
|
| 832 |
|
| 833 |
Returns:
|
|
@@ -902,13 +906,13 @@ class TransformerStack(nn.Module):
|
|
| 902 |
self,
|
| 903 |
x: torch.Tensor,
|
| 904 |
attention_mask: Optional[torch.Tensor] = None,
|
| 905 |
-
output_hidden_states: bool = False,
|
| 906 |
-
output_attentions: bool = False,
|
| 907 |
) -> TransformerOutput:
|
| 908 |
"""
|
| 909 |
Args:
|
| 910 |
x: Input tensor
|
| 911 |
-
attention_mask: Optional attention mask
|
| 912 |
output_hidden_states: Whether to return all hidden states
|
| 913 |
output_attentions: Whether to return attention weights
|
| 914 |
|
|
@@ -918,33 +922,31 @@ class TransformerStack(nn.Module):
|
|
| 918 |
hidden_states = () if output_hidden_states else None
|
| 919 |
attentions = () if output_attentions else None
|
| 920 |
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
attention_mask = None
|
| 930 |
-
else:
|
| 931 |
-
pairwise_attention_mask = token_attention_mask.unsqueeze(-1) & token_attention_mask.unsqueeze(-2)
|
| 932 |
-
attention_mask = pairwise_attention_mask.unsqueeze(1)
|
| 933 |
-
flex_block_mask = None
|
| 934 |
-
else:
|
| 935 |
-
flex_block_mask = None
|
| 936 |
|
| 937 |
for block in self.blocks:
|
| 938 |
if self.gradient_checkpointing and self.training:
|
| 939 |
x, attn_weights = self._gradient_checkpointing_func(
|
| 940 |
block.__call__,
|
| 941 |
-
x,
|
| 942 |
-
attention_mask,
|
| 943 |
-
flex_block_mask,
|
| 944 |
-
output_attentions,
|
| 945 |
)
|
| 946 |
else:
|
| 947 |
-
x, attn_weights = block(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 948 |
|
| 949 |
if attentions is not None:
|
| 950 |
attentions += (attn_weights,)
|
|
@@ -952,9 +954,13 @@ class TransformerStack(nn.Module):
|
|
| 952 |
if output_hidden_states:
|
| 953 |
assert hidden_states is not None
|
| 954 |
hidden_states += (x,)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 955 |
|
| 956 |
return TransformerOutput(
|
| 957 |
-
last_hidden_state=
|
| 958 |
hidden_states=hidden_states,
|
| 959 |
attentions=attentions
|
| 960 |
)
|
|
@@ -1048,7 +1054,12 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 1048 |
|
| 1049 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1050 |
x = self.embed(input_ids)
|
| 1051 |
-
return self.transformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1052 |
|
| 1053 |
def forward(
|
| 1054 |
self,
|
|
@@ -1072,11 +1083,20 @@ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 1072 |
Returns:
|
| 1073 |
TransformerOutput containing last hidden state and optionally all hidden states and attention weights
|
| 1074 |
"""
|
|
|
|
|
|
|
|
|
|
| 1075 |
if inputs_embeds is None:
|
| 1076 |
x = self.embed(input_ids)
|
| 1077 |
else:
|
| 1078 |
x = inputs_embeds
|
| 1079 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1080 |
|
| 1081 |
|
| 1082 |
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
@@ -1116,7 +1136,12 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 1116 |
|
| 1117 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1118 |
x = self.embed(input_ids)
|
| 1119 |
-
return self.transformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1120 |
|
| 1121 |
def forward(
|
| 1122 |
self,
|
|
@@ -1146,16 +1171,24 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
| 1146 |
x = self.embed(input_ids)
|
| 1147 |
else:
|
| 1148 |
x = inputs_embeds
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1152 |
loss = None
|
| 1153 |
if labels is not None:
|
| 1154 |
loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
|
|
|
|
| 1155 |
return ESMplusplusOutput(
|
| 1156 |
loss=loss,
|
| 1157 |
logits=logits,
|
| 1158 |
-
last_hidden_state=
|
| 1159 |
hidden_states=output.hidden_states,
|
| 1160 |
attentions=output.attentions,
|
| 1161 |
)
|
|
@@ -1185,7 +1218,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 1185 |
|
| 1186 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1187 |
x = self.embed(input_ids)
|
| 1188 |
-
return self.transformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1189 |
|
| 1190 |
def forward(
|
| 1191 |
self,
|
|
@@ -1219,9 +1257,11 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 1219 |
output_attentions=output_attentions,
|
| 1220 |
output_hidden_states=output_hidden_states
|
| 1221 |
)
|
| 1222 |
-
|
| 1223 |
-
|
|
|
|
| 1224 |
logits = self.classifier(features)
|
|
|
|
| 1225 |
loss = None
|
| 1226 |
if labels is not None:
|
| 1227 |
labels = labels.to(logits.device)
|
|
@@ -1246,7 +1286,7 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 1246 |
return ESMplusplusOutput(
|
| 1247 |
loss=loss,
|
| 1248 |
logits=logits,
|
| 1249 |
-
last_hidden_state=
|
| 1250 |
hidden_states=output.hidden_states,
|
| 1251 |
attentions=output.attentions,
|
| 1252 |
)
|
|
@@ -1302,15 +1342,17 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
|
|
| 1302 |
output_attentions=output_attentions,
|
| 1303 |
output_hidden_states=output_hidden_states
|
| 1304 |
)
|
| 1305 |
-
|
| 1306 |
-
|
|
|
|
| 1307 |
loss = None
|
| 1308 |
if labels is not None:
|
| 1309 |
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
|
|
|
| 1310 |
return ESMplusplusOutput(
|
| 1311 |
loss=loss,
|
| 1312 |
logits=logits,
|
| 1313 |
-
last_hidden_state=
|
| 1314 |
hidden_states=output.hidden_states,
|
| 1315 |
attentions=output.attentions,
|
| 1316 |
)
|
|
@@ -1487,5 +1529,3 @@ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
|
|
| 1487 |
@property
|
| 1488 |
def special_token_ids(self):
|
| 1489 |
return self.all_special_ids
|
| 1490 |
-
|
| 1491 |
-
|
|
|
|
| 391 |
flex_attention = None
|
| 392 |
|
| 393 |
|
| 394 |
+
def get_attention_mask(
|
| 395 |
+
attn_backend: str,
|
| 396 |
+
batch_size: int,
|
| 397 |
+
seq_len: int,
|
| 398 |
+
device: torch.device,
|
| 399 |
+
attention_mask: Optional[torch.Tensor] = None
|
| 400 |
+
) -> torch.Tensor:
|
| 401 |
+
if attention_mask is None:
|
| 402 |
+
token_attention_mask = torch.ones((batch_size, seq_len), device=device).bool()
|
| 403 |
+
else:
|
| 404 |
+
token_attention_mask = attention_mask.bool()
|
| 405 |
+
|
| 406 |
+
if attn_backend == "flex":
|
| 407 |
+
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
| 408 |
+
|
| 409 |
+
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 410 |
+
return token_attention_mask[batch_idx, q_idx] & token_attention_mask[batch_idx, kv_idx]
|
| 411 |
+
|
| 412 |
+
flex_block_mask = create_block_mask(
|
| 413 |
+
mask_mod,
|
| 414 |
+
batch_size,
|
| 415 |
+
1,
|
| 416 |
+
seq_len,
|
| 417 |
+
seq_len,
|
| 418 |
+
device=device,
|
| 419 |
+
)
|
| 420 |
+
extended_attention_mask = None
|
| 421 |
+
else:
|
| 422 |
+
flex_block_mask = None
|
| 423 |
+
extended_attention_mask = token_attention_mask[:, None, :, None] & token_attention_mask[:, None, None, :]
|
| 424 |
|
| 425 |
+
return extended_attention_mask, flex_block_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
|
| 428 |
class ESMplusplusConfig(PretrainedConfig):
|
|
|
|
| 717 |
def forward(
|
| 718 |
self,
|
| 719 |
x: torch.Tensor,
|
| 720 |
+
attention_mask: torch.Tensor,
|
| 721 |
+
flex_block_mask: object,
|
| 722 |
output_attentions: bool = False,
|
| 723 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 724 |
"""
|
| 725 |
Args:
|
| 726 |
x: Input tensor
|
| 727 |
+
attention_mask: 4D attention mask
|
| 728 |
+
flex_block_mask: Flex attention block mask
|
| 729 |
output_attentions: Whether to return attention weights
|
| 730 |
|
| 731 |
Returns:
|
|
|
|
| 743 |
scale = 1 / math.sqrt(self.d_head)
|
| 744 |
|
| 745 |
if output_attentions: # Manual attention computation
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
| 747 |
+
attn_weights = attn_weights.masked_fill(attention_mask.logical_not(), float('-inf'))
|
| 748 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
| 749 |
context_BHLD = torch.matmul(attn_weights, value_BHLD)
|
| 750 |
else:
|
| 751 |
if self.attn_backend == "flex":
|
| 752 |
assert flex_attention is not None, "Flex attention backend requested but torch.flex_attention is unavailable."
|
| 753 |
+
assert query_BHLD.dtype in (torch.float16, torch.bfloat16), f"Flex attention backend requires float16 or bfloat16, got {query_BHLD.dtype}."
|
| 754 |
+
assert flex_block_mask is not None, "Flex attention backend requires a block mask when attention_mask is provided."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
context_BHLD = flex_attention(
|
| 756 |
query_BHLD,
|
| 757 |
key_BHLD,
|
|
|
|
| 760 |
scale=scale,
|
| 761 |
)
|
| 762 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
context_BHLD = F.scaled_dot_product_attention(
|
| 764 |
query_BHLD,
|
| 765 |
key_BHLD,
|
| 766 |
value_BHLD,
|
| 767 |
+
attn_mask=attention_mask,
|
| 768 |
scale=scale,
|
| 769 |
)
|
| 770 |
|
|
|
|
| 823 |
def forward(
|
| 824 |
self,
|
| 825 |
x: torch.Tensor,
|
| 826 |
+
attention_mask: torch.Tensor,
|
| 827 |
+
flex_block_mask: object,
|
| 828 |
output_attentions: bool = False,
|
| 829 |
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 830 |
"""
|
| 831 |
Args:
|
| 832 |
x: Input tensor
|
| 833 |
+
attention_mask: 4D attention mask
|
| 834 |
+
flex_block_mask: Flex attention block mask
|
| 835 |
output_attentions: Whether to return attention weights
|
| 836 |
|
| 837 |
Returns:
|
|
|
|
| 906 |
self,
|
| 907 |
x: torch.Tensor,
|
| 908 |
attention_mask: Optional[torch.Tensor] = None,
|
| 909 |
+
output_hidden_states: Optional[bool] = False,
|
| 910 |
+
output_attentions: Optional[bool] = False,
|
| 911 |
) -> TransformerOutput:
|
| 912 |
"""
|
| 913 |
Args:
|
| 914 |
x: Input tensor
|
| 915 |
+
attention_mask: Optional 2D attention mask
|
| 916 |
output_hidden_states: Whether to return all hidden states
|
| 917 |
output_attentions: Whether to return attention weights
|
| 918 |
|
|
|
|
| 922 |
hidden_states = () if output_hidden_states else None
|
| 923 |
attentions = () if output_attentions else None
|
| 924 |
|
| 925 |
+
# move to 4D attention mask or flex block mask
|
| 926 |
+
attention_mask, flex_block_mask = get_attention_mask(
|
| 927 |
+
attn_backend=self.attn_backend,
|
| 928 |
+
batch_size=x.shape[0],
|
| 929 |
+
seq_len=x.shape[1],
|
| 930 |
+
device=x.device,
|
| 931 |
+
attention_mask=attention_mask,
|
| 932 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 933 |
|
| 934 |
for block in self.blocks:
|
| 935 |
if self.gradient_checkpointing and self.training:
|
| 936 |
x, attn_weights = self._gradient_checkpointing_func(
|
| 937 |
block.__call__,
|
| 938 |
+
x=x,
|
| 939 |
+
attention_mask=attention_mask,
|
| 940 |
+
flex_block_mask=flex_block_mask,
|
| 941 |
+
output_attentions=output_attentions,
|
| 942 |
)
|
| 943 |
else:
|
| 944 |
+
x, attn_weights = block(
|
| 945 |
+
x=x,
|
| 946 |
+
attention_mask=attention_mask,
|
| 947 |
+
flex_block_mask=flex_block_mask,
|
| 948 |
+
output_attentions=output_attentions,
|
| 949 |
+
)
|
| 950 |
|
| 951 |
if attentions is not None:
|
| 952 |
attentions += (attn_weights,)
|
|
|
|
| 954 |
if output_hidden_states:
|
| 955 |
assert hidden_states is not None
|
| 956 |
hidden_states += (x,)
|
| 957 |
+
|
| 958 |
+
last_hidden_state = self.norm(x)
|
| 959 |
+
if output_hidden_states:
|
| 960 |
+
hidden_states += (last_hidden_state,)
|
| 961 |
|
| 962 |
return TransformerOutput(
|
| 963 |
+
last_hidden_state=last_hidden_state,
|
| 964 |
hidden_states=hidden_states,
|
| 965 |
attentions=attentions
|
| 966 |
)
|
|
|
|
| 1054 |
|
| 1055 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1056 |
x = self.embed(input_ids)
|
| 1057 |
+
return self.transformer(
|
| 1058 |
+
x=x,
|
| 1059 |
+
attention_mask=attention_mask,
|
| 1060 |
+
output_hidden_states=False,
|
| 1061 |
+
output_attentions=False,
|
| 1062 |
+
).last_hidden_state
|
| 1063 |
|
| 1064 |
def forward(
|
| 1065 |
self,
|
|
|
|
| 1083 |
Returns:
|
| 1084 |
TransformerOutput containing last hidden state and optionally all hidden states and attention weights
|
| 1085 |
"""
|
| 1086 |
+
assert input_ids is not None or inputs_embeds is not None, "You have to specify either input_ids or inputs_embeds"
|
| 1087 |
+
assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds at the same time"
|
| 1088 |
+
|
| 1089 |
if inputs_embeds is None:
|
| 1090 |
x = self.embed(input_ids)
|
| 1091 |
else:
|
| 1092 |
x = inputs_embeds
|
| 1093 |
+
|
| 1094 |
+
return self.transformer(
|
| 1095 |
+
x=x,
|
| 1096 |
+
attention_mask=attention_mask,
|
| 1097 |
+
output_hidden_states=output_hidden_states,
|
| 1098 |
+
output_attentions=output_attentions,
|
| 1099 |
+
).last_hidden_state
|
| 1100 |
|
| 1101 |
|
| 1102 |
class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
|
|
|
|
| 1136 |
|
| 1137 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1138 |
x = self.embed(input_ids)
|
| 1139 |
+
return self.transformer(
|
| 1140 |
+
x=x,
|
| 1141 |
+
attention_mask=attention_mask,
|
| 1142 |
+
output_hidden_states=False,
|
| 1143 |
+
output_attentions=False,
|
| 1144 |
+
).last_hidden_state
|
| 1145 |
|
| 1146 |
def forward(
|
| 1147 |
self,
|
|
|
|
| 1171 |
x = self.embed(input_ids)
|
| 1172 |
else:
|
| 1173 |
x = inputs_embeds
|
| 1174 |
+
|
| 1175 |
+
output = self.transformer(
|
| 1176 |
+
x=x,
|
| 1177 |
+
attention_mask=attention_mask,
|
| 1178 |
+
output_hidden_states=output_hidden_states,
|
| 1179 |
+
output_attentions=output_attentions,
|
| 1180 |
+
)
|
| 1181 |
+
|
| 1182 |
+
last_hidden_state = output.last_hidden_state
|
| 1183 |
+
logits = self.sequence_head(last_hidden_state)
|
| 1184 |
loss = None
|
| 1185 |
if labels is not None:
|
| 1186 |
loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
|
| 1187 |
+
|
| 1188 |
return ESMplusplusOutput(
|
| 1189 |
loss=loss,
|
| 1190 |
logits=logits,
|
| 1191 |
+
last_hidden_state=last_hidden_state,
|
| 1192 |
hidden_states=output.hidden_states,
|
| 1193 |
attentions=output.attentions,
|
| 1194 |
)
|
|
|
|
| 1218 |
|
| 1219 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1220 |
x = self.embed(input_ids)
|
| 1221 |
+
return self.transformer(
|
| 1222 |
+
x=x,
|
| 1223 |
+
attention_mask=attention_mask,
|
| 1224 |
+
output_hidden_states=False,
|
| 1225 |
+
output_attentions=False,
|
| 1226 |
+
).last_hidden_state
|
| 1227 |
|
| 1228 |
def forward(
|
| 1229 |
self,
|
|
|
|
| 1257 |
output_attentions=output_attentions,
|
| 1258 |
output_hidden_states=output_hidden_states
|
| 1259 |
)
|
| 1260 |
+
|
| 1261 |
+
last_hidden_state = output.last_hidden_state
|
| 1262 |
+
features = self.pooler(last_hidden_state, attention_mask) # pooler expects 2d attention mask
|
| 1263 |
logits = self.classifier(features)
|
| 1264 |
+
|
| 1265 |
loss = None
|
| 1266 |
if labels is not None:
|
| 1267 |
labels = labels.to(logits.device)
|
|
|
|
| 1286 |
return ESMplusplusOutput(
|
| 1287 |
loss=loss,
|
| 1288 |
logits=logits,
|
| 1289 |
+
last_hidden_state=last_hidden_state,
|
| 1290 |
hidden_states=output.hidden_states,
|
| 1291 |
attentions=output.attentions,
|
| 1292 |
)
|
|
|
|
| 1342 |
output_attentions=output_attentions,
|
| 1343 |
output_hidden_states=output_hidden_states
|
| 1344 |
)
|
| 1345 |
+
|
| 1346 |
+
last_hidden_state = output.last_hidden_state
|
| 1347 |
+
logits = self.classifier(last_hidden_state)
|
| 1348 |
loss = None
|
| 1349 |
if labels is not None:
|
| 1350 |
loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1351 |
+
|
| 1352 |
return ESMplusplusOutput(
|
| 1353 |
loss=loss,
|
| 1354 |
logits=logits,
|
| 1355 |
+
last_hidden_state=last_hidden_state,
|
| 1356 |
hidden_states=output.hidden_states,
|
| 1357 |
attentions=output.attentions,
|
| 1358 |
)
|
|
|
|
| 1529 |
@property
|
| 1530 |
def special_token_ids(self):
|
| 1531 |
return self.all_special_ids
|
|
|
|
|
|