EscheWang committed
Commit 0867146 · verified · 1 Parent(s): 73a97e6

Upload folder using huggingface_hub

Files changed (7)
  1. config.json +28 -0
  2. config.py +127 -0
  3. model.py +970 -0
  4. model.safetensors +3 -0
  5. tokenizer.py +162 -0
  6. tokenizer_config.json +60 -0
  7. vocab.txt +10 -0
config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "architectures": [
+     "UniRNAForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.0,
+   "emb_layer_norm_before": true,
+   "hidden_dropout_prob": 0.0,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-05,
+   "mask_token_id": 4,
+   "max_position_embeddings": 1026,
+   "model_type": "unirna",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 16,
+   "pad_token_id": 0,
+   "position_embedding_type": "rotary",
+   "sep_token_id": 1,
+   "token_dropout": true,
+   "transformers_version": "4.26.1",
+   "vocab_size": 10,
+   "auto_map": {
+     "AutoConfig": "config.UniRNAConfig",
+     "AutoModel": "model.UniRNAModels",
+     "AutoModelForMaskedLM": "model.UniRNAForMaskedLM"
+   }
+ }
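
The auto_map block above registers the custom classes from config.py and model.py with the transformers Auto* loaders, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal loading sketch (the repository path is a placeholder, not part of this upload):

from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

repo = "path/to/this/checkpoint"  # placeholder: local clone or Hub repo id
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)            # -> config.UniRNAConfig
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)      # -> tokenizer.UniRNATokenizer
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True)   # -> model.UniRNAForMaskedLM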
config.py ADDED
@@ -0,0 +1,127 @@
+ import os
+
+ from transformers import PretrainedConfig
+
+
+ class UniRNAConfig(PretrainedConfig):
+     """Configuration for UniRNA models."""
+
+     model_type: str = "unirna"
+
+     def __init__(
+         self,
+         vocab_size: int = 10,
+         hidden_size: int = 768,
+         num_hidden_layers: int = 12,
+         num_attention_heads: int = 12,
+         intermediate_size: int = 3072,
+         hidden_dropout_prob: float = 0.0,
+         attention_probs_dropout_prob: float = 0.0,
+         max_position_embeddings: int = 1026,
+         layer_norm_eps: float = 1e-5,
+         pad_token_id: int = 0,
+         sep_token_id: int = 1,
+         cls_token_id: int = 3,
+         mask_token_id: int = 4,
+         emb_layer_norm_before: bool = True,
+         token_dropout: bool = True,
+         position_embedding_type: str = "rotary",
+         use_flash_attention: bool = False,
+         tie_word_embeddings: bool = False,
+         is_decoder: bool = False,
+         **kwargs,
+     ):
+         # Ensure attribute exists before any access.
+         self.architectures = kwargs.get("architectures", None)
+         super().__init__(
+             pad_token_id=pad_token_id,
+             sep_token_id=sep_token_id,
+             cls_token_id=cls_token_id,
+             mask_token_id=mask_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             is_decoder=is_decoder,
+             **kwargs,
+         )
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.layer_norm_eps = layer_norm_eps
+         self.emb_layer_norm_before = emb_layer_norm_before
+         self.token_dropout = token_dropout
+         self.position_embedding_type = position_embedding_type
+         self.use_flash_attention = use_flash_attention
+         if self.architectures is None:
+             self.architectures = ["UniRNAForMaskedLM"]
+
+
+ def build_config(path):
+     path = os.path.splitext(path)[0]
+     name = os.path.basename(path)
+     model_type, num_hidden_layers, hidden_size, _ = name.split("_")[:4]
+     num_hidden_layers = int(num_hidden_layers[1:])
+     hidden_size = int(hidden_size[1:])
+     num_attention_heads = hidden_size // 64
+     intermediate_size = hidden_size * 3
+     config = UniRNAConfig(
+         model_type=model_type,
+         num_hidden_layers=num_hidden_layers,
+         hidden_size=hidden_size,
+         num_attention_heads=num_attention_heads,
+         intermediate_size=intermediate_size,
+         pad_token_id=0,
+         sep_token_id=1,
+         mask_token_id=4,
+         cls_token_id=3,
+         vocab_size=10,
+         emb_layer_norm_before=True,
+         layer_norm_eps=1e-5,
+         hidden_dropout_prob=0.0,
+         attention_probs_dropout_prob=0.0,
+         token_dropout=True,
+         initializer_range=0.02,
+         use_flash_attention=True,
+         max_position_embeddings=1026,
+         position_embedding_type="rotary",
+         tie_word_embeddings=False,
+     )
+     config._name_or_path = name
+     return config
+
+
+ def build_config_GENE(path, num_hidden_layers: int, hidden_size: int, vocab_size: int, model_type="GENE"):
+     path = os.path.splitext(path)[0]
+     name = os.path.basename(path)
+     # model_type, num_hidden_layers, hidden_size, _ = name.split("_")[:4]
+     num_hidden_layers = int(num_hidden_layers)
+     hidden_size = int(hidden_size)
+     num_attention_heads = hidden_size // 64
+     intermediate_size = hidden_size * 4
+     config = UniRNAConfig(
+         model_type=model_type,
+         num_hidden_layers=num_hidden_layers,
+         hidden_size=hidden_size,
+         num_attention_heads=num_attention_heads,
+         intermediate_size=intermediate_size,
+         pad_token_id=0,
+         sep_token_id=1,
+         mask_token_id=4,
+         cls_token_id=3,
+         vocab_size=vocab_size,
+         emb_layer_norm_before=True,
+         layer_norm_eps=1e-5,
+         hidden_dropout_prob=0.0,
+         attention_probs_dropout_prob=0.0,
+         token_dropout=True,
+         initializer_range=0.02,
+         use_flash_attention=True,
+         max_position_embeddings=1026,
+         position_embedding_type="rotary",
+         tie_word_embeddings=False,
+     )
+     config._name_or_path = name
+     return config
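
build_config infers hyperparameters from the checkpoint filename: the stem is split on "_", the second and third fields are read as L<num_hidden_layers> and H<hidden_size>, heads are hidden_size // 64, and the FFN size is 3 * hidden_size. A sketch with a hypothetical filename that follows this convention:

from config import build_config  # assumes the repository directory is on sys.path

cfg = build_config("checkpoints/unirna_L16_H1024_pretrain.pt")  # hypothetical checkpoint name
print(cfg.num_hidden_layers, cfg.hidden_size, cfg.num_attention_heads, cfg.intermediate_size)
# 16 1024 16 3072  (matches the values in config.json above)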
model.py ADDED
@@ -0,0 +1,970 @@
+ """
+ The code is modified from the EsmModel in the transformers library.
+ Sources: https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/modeling_esm.py
+ """
+
+ from dataclasses import dataclass
+ from functools import partial
+ from typing import Optional, Sequence, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPastAndCrossAttentions,
+     BaseModelOutputWithPoolingAndCrossAttentions,
+     MaskedLMOutput,
+     ModelOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import logging
+
+ from .config import UniRNAConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ @dataclass
+ class UniRNASSPredictionOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: Optional[torch.FloatTensor] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+     pair_mask: Optional[torch.BoolTensor] = None
+
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(x, cos, sin):
+     cos = cos[:, :, : x.shape[-2], :]
+     sin = sin[:, :, : x.shape[-2], :]
+
+     return (x * cos) + (rotate_half(x) * sin)
+
+
+ class RotaryEmbedding(nn.Module):
+     """
+     Rotary position embeddings based on those in
+     [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
+     matrices which depend on their relative positions.
+     """
+
+     def __init__(self, dim: int):
+         super().__init__()
+         # Generate and save the inverse frequency buffer (non-trainable)
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         inv_freq = inv_freq
+         self.register_buffer("inv_freq", inv_freq)
+
+         self._seq_len_cached = None
+         self._cos_cached = None
+         self._sin_cached = None
+
+     def _update_cos_sin_tables(self, x, seq_dimension=2):
+         seq_len = x.shape[seq_dimension]
+
+         # Reset the tables if the sequence length has changed,
+         # or if we're on a new device (possibly due to tracing for instance)
+         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+             self._seq_len_cached = seq_len
+             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
+             freqs = torch.outer(t, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+             self._cos_cached = emb.cos()[None, None, :, :]
+             self._sin_cached = emb.sin()[None, None, :, :]
+
+         return self._cos_cached, self._sin_cached
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
+
+         return (
+             apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+             apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+         )
+
+
+ class UniRNAEmbedding(nn.Module):
+     """
+     Same as BertEmbeddings with an additional token_dropout.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+
+         if config.emb_layer_norm_before:
+             self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         else:
+             self.layer_norm = None
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         self.padding_idx = config.pad_token_id
+         self.token_dropout = config.token_dropout
+         self.mask_token_id = config.mask_token_id
+
+     def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None):
+         if inputs_embeds is None:
+             inputs_embeds = self.word_embeddings(input_ids)
+
+         embeddings = inputs_embeds
+         if attention_mask is None:
+             attention_mask = torch.ones(embeddings.shape[:2], device=embeddings.device)
+
+         # By default, we use token dropout, as in ESM-2.
+         if self.layer_norm is not None:
+             embeddings = self.layer_norm(embeddings)
+         if attention_mask is not None:
+             embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
+
+         embeddings = self.dropout(embeddings)
+         if self.token_dropout and input_ids is not None:
+             embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+             # 0.15 is MaskedLM's default mask probability, and 0.8 is the default keep probability
+             mask_ratio_train = 0.15 * 0.8
+             src_lengths = attention_mask.sum(-1).clamp(min=1).to(embeddings.dtype)
+             mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).to(embeddings.dtype) / src_lengths
+             denom = (1 - mask_ratio_observed).clamp(min=1e-6)
+             embeddings = (embeddings * (1 - mask_ratio_train) / denom[:, None, None]).to(embeddings.dtype)
+
+         return embeddings
+
+
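The token-dropout branch zeroes the embeddings at <mask> positions and rescales the remaining embeddings by (1 - mask_ratio_train) / (1 - mask_ratio_observed), so that the expected embedding magnitude matches pretraining, where on average 0.15 * 0.8 = 12% of tokens are masked. A small numeric sketch of that scale factor:

mask_ratio_train = 0.15 * 0.8               # 0.12
scale_no_masks = (1 - 0.12) / (1 - 0.0)     # 0.88: inference input containing no <mask> tokens
scale_mlm_batch = (1 - 0.12) / (1 - 0.12)   # 1.00: batch where 12% of the tokens are <mask>
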
+ class UniRNASelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+             raise ValueError(
+                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                 f"heads ({config.num_attention_heads})"
+             )
+
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+         self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
+
+         self.is_decoder = config.is_decoder
+
+     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+         new_x_shape = x.size()[:-1] + (
+             self.num_attention_heads,
+             self.attention_head_size,
+         )
+         x = x.view(new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         mixed_query_layer = self.query(hidden_states)
+
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+
+         # Hardcoded from EsmModel provided by transformers
+         query_layer = query_layer * self.attention_head_size**-0.5
+
+         # Apply rotary embeddings
+         query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         # For faster computation, you can use torch.nn.functional.scaled_dot_product_attention
+
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the UniRNAModel forward() function)
+             attention_scores = attention_scores + attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(new_context_layer_shape)
+
+         outputs = (context_layer, attention_probs) if output_attentions else (context_layer, None)
+
+         return outputs
+
+
+ class UniRNAFlashSelfAttention(UniRNASelfAttention):
+     """Self-attention using PyTorch's scaled_dot_product_attention (SDPA) backend."""
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.dropout_prob = config.attention_probs_dropout_prob
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = False,
+     ) -> Tuple[torch.Tensor]:
+         if output_attentions:
+             raise ValueError("SDPA attention does not support output_attentions=True")
+
+         mixed_query_layer = self.query(hidden_states)
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+
+         # Same manual scaling as UniRNASelfAttention
+         query_layer = query_layer * self.attention_head_size**-0.5
+
+         # Apply rotary embeddings
+         query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+
+         # Use PyTorch SDPA; scale=1.0 because we already scaled query above
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             query_layer,
+             key_layer,
+             value_layer,
+             attn_mask=attention_mask,
+             dropout_p=self.dropout_prob if self.training else 0.0,
+             scale=1.0,
+         )
+
+         attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+         new_shape = attn_output.size()[:-2] + (self.all_head_size,)
+         attn_output = attn_output.view(new_shape)
+
+         return (attn_output, None)
+
+
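Because the query is already scaled by attention_head_size**-0.5 and SDPA is called with scale=1.0 and the same additive mask, the SDPA path should match the eager path up to floating-point error when dropout is zero. A self-check sketch, assuming UniRNAConfig and the two attention classes are importable (e.g. with this repository on the path; model.py uses a relative import, so a package-style import may be needed):

import torch
from config import UniRNAConfig
from model import UniRNASelfAttention, UniRNAFlashSelfAttention

cfg = UniRNAConfig(hidden_size=64, num_attention_heads=4)
eager = UniRNASelfAttention(cfg).eval()
sdpa = UniRNAFlashSelfAttention(cfg).eval()
sdpa.load_state_dict(eager.state_dict())  # share the same Q/K/V weights

x = torch.randn(2, 16, 64)
with torch.no_grad():
    print(torch.allclose(eager(x)[0], sdpa(x)[0], atol=1e-5))  # expected: True
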
+ class UniRNASelfOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states, input_tensor):
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = hidden_states + input_tensor
+         return hidden_states
+
+
+ class UniRNA_Attention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         if getattr(config, "use_flash_attention", False):
+             self.self = UniRNAFlashSelfAttention(config)
+         else:
+             self.self = UniRNASelfAttention(config)
+         self.output = UniRNASelfOutput(config)
+         self.pruned_heads = set()
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     # TODO: add pruning heads
+     # def prune_heads(self, heads):
+     #     if len(heads) == 0:
+     #         return
+     #     heads, index = find_pruneable_heads_and_indices(
+     #         heads,
+     #         self.self.num_attention_heads,
+     #         self.self.attention_head_size,
+     #         self.pruned_heads,
+     #     )
+
+     #     # Prune linear layers
+     #     self.self.query = prune_linear_layer(self.self.query, index)
+     #     self.self.key = prune_linear_layer(self.self.key, index)
+     #     self.self.value = prune_linear_layer(self.self.value, index)
+     #     self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+     #     # Update hyper params and store pruned heads
+     #     self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+     #     self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+     #     self.pruned_heads = self.pruned_heads.union(heads)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         output_attentions=False,
+     ):
+         hidden_states_ln = self.LayerNorm(hidden_states)
+         self_outputs = self.self(
+             hidden_states_ln,
+             attention_mask,
+             output_attentions,
+         )
+         attention_output = self.output(self_outputs[0], hidden_states)
+         return (attention_output, self_outputs[1])
+
+
+ class UniRNAIntermediate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = nn.functional.gelu(hidden_states)
+         return hidden_states
+
+
+ class UniRNAOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states, input_tensor):
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = hidden_states + input_tensor
+         return hidden_states
+
+
+ class UniRNALayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = UniRNA_Attention(config)
+         self.intermediate = UniRNAIntermediate(config)
+         self.output = UniRNAOutput(config)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         output_attentions=False,
+     ):
+         self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions)
+         layer_output = self.feed_forward_chunk(self_attention_outputs[0])
+         return (layer_output, self_attention_outputs[1])
+
+     def feed_forward_chunk(self, attention_output):
+         attention_output_ln = self.LayerNorm(attention_output)
+         intermediate_output = self.intermediate(attention_output_ln)
+         layer_output = self.output(intermediate_output, attention_output)
+         return layer_output
+
+
+ class UniRNAEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.layer = nn.ModuleList([UniRNALayer(config) for _ in range(config.num_hidden_layers)])
+         self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         output_attentions=False,
+         output_hidden_states=False,
+     ):
+
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+
+         for layer_module in self.layer:
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(
+                     hidden_states,
+                     attention_mask,
+                     output_attentions,
+                 )
+             hidden_states = layer_outputs[0]
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+         if self.emb_layer_norm_after:
+             hidden_states = self.emb_layer_norm_after(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
+ class UniRNAPooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[:, 0]
+         pooled_output = self.dense(first_token_tensor)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
+ class UniRNAModel(PreTrainedModel):
+     config_class = UniRNAConfig
+     supports_gradient_checkpointing = True
+     main_input_name = "input_ids"
+     """
+
+     The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+     cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+     all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+     Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+     """
+
+     def __init__(self, config, add_pooling_layer=True):
+         super().__init__(config)
+         self.config = config
+         self.embeddings = UniRNAEmbedding(config)
+         self.encoder = UniRNAEncoder(config)
+         self.pooler = UniRNAPooler(config) if add_pooling_layer else None
+
+         use_flash_attention = getattr(config, "use_flash_attention", False)
+         if use_flash_attention:
+             logger.info("Using Uni-RNA SDPA Attention")
+         else:
+             logger.info("Using Uni-RNA Attention")
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def _set_gradient_checkpointing(self, enable: bool, gradient_checkpointing_func=None):
+         self.encoder.gradient_checkpointing = enable
+         if gradient_checkpointing_func is not None:
+             self.encoder._gradient_checkpointing_func = gradient_checkpointing_func
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+         class PreTrainedModel
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         r"""
+         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+             the model is configured as a decoder.
+         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+         past_key_values (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
+             Tuple of length `config.n_layers`. Each tuple has 4 tensors of shape
+             `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`. Contains precomputed key and value
+             hidden states of the attention blocks. Can be used to speed up decoding.
+
+             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+         use_cache (`bool`, *optional*):
+             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+             `past_key_values`).
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         input_shape, attention_mask = self._validate_and_shape_inputs(input_ids, inputs_embeds, attention_mask)
+         extended_attention_mask = self._prepare_attention_mask(attention_mask, input_shape)
+         embedding_output = self._compute_embedding_output(input_ids, attention_mask, inputs_embeds)
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+         sequence_output, pooled_output = self._pool_outputs(encoder_outputs[0], attention_mask)
+
+         if not return_dict:
+             output = (sequence_output, pooled_output) + encoder_outputs[1:]
+             return output
+
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             past_key_values=encoder_outputs.past_key_values,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+             cross_attentions=encoder_outputs.cross_attentions,
+         )
+
+     def _validate_and_shape_inputs(
+         self,
+         input_ids: Optional[torch.Tensor],
+         inputs_embeds: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+     ) -> Tuple[Tuple[int, ...], torch.Tensor]:
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         if input_ids is None and inputs_embeds is None:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+             device = input_ids.device
+         else:
+             input_shape = inputs_embeds.size()[:-1]
+             device = inputs_embeds.device
+
+         batch_size, seq_length = input_shape
+         if attention_mask is None:
+             attention_mask = torch.ones((batch_size, seq_length), device=device)
+         return input_shape, attention_mask
+
+     def _prepare_attention_mask(self, attention_mask: torch.Tensor, input_shape: Tuple[int, ...]) -> torch.Tensor:
+         return self.get_extended_attention_mask(attention_mask, input_shape)
+
+     def _compute_embedding_output(
+         self,
+         input_ids: Optional[torch.Tensor],
+         attention_mask: torch.Tensor,
+         inputs_embeds: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         return self.embeddings(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
+
+     def _pool_outputs(
+         self, sequence_output: torch.Tensor, attention_mask: torch.Tensor
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         # make it compatible with deepprotein, which wraps the model with a different pooler
+         try:
+             pooled_output = self.pooler(sequence_output, attention_mask) if self.pooler is not None else None
+         except TypeError:
+             pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+         return sequence_output, pooled_output
+
+
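Through the auto_map entry AutoModel -> model.UniRNAModels, the encoder can be used as a feature extractor; the pooled output comes from the AvgPooler defined later in this file, which averages hidden states over the non-special, non-padded positions. A usage sketch (placeholder path):

import torch
from transformers import AutoModel, AutoTokenizer

repo = "path/to/this/checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()

inputs = tokenizer("ACGUACGU", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs)
per_token = out.last_hidden_state   # (1, 10, 1024): <cls> + 8 nucleotides + <eos>
per_seq = out.pooler_output         # (1, 1024): length-masked mean over the 8 nucleotide positions
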
+ class UniRNAForMaskedLM(PreTrainedModel):
+     _tied_weights_keys = ["lm_head.decoder.weight"]
+     config_class = UniRNAConfig
+     supports_gradient_checkpointing = True
+     main_input_name = "input_ids"
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.config = config
+         self.embeddings = UniRNAEmbedding(config)
+         self.encoder = UniRNAEncoder(config)
+         self.lm_head = UniRNALMHead(config)
+
+         self.post_init()
+
+     def _set_gradient_checkpointing(self, enable: bool, gradient_checkpointing_func=None):
+         self.encoder.gradient_checkpointing = enable
+         if gradient_checkpointing_func is not None:
+             self.encoder._gradient_checkpointing_func = gradient_checkpointing_func
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
+             Used to hide legacy arguments that have been deprecated.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         if attention_mask is None:
+             attention_mask = torch.ones((batch_size, seq_length), device=device)
+
+         extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+         )
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+         sequence_output = encoder_outputs[0]
+
+         prediction_scores = self.lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores,) + encoder_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=prediction_scores,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
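A short masked-nucleotide prediction sketch for this class (placeholder path):

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo = "path/to/this/checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo, trust_remote_code=True).eval()

inputs = tokenizer("AC<mask>GU", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits                                   # (1, 7, 10)
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
print(tokenizer.convert_ids_to_tokens(logits[mask_pos].argmax(-1).tolist()))
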
+ class UniRNAForSSPredict(PreTrainedModel):
+     """
+     TODO: make it compatible with transformers, create new 'modeling_outputs' class for SS prediction
+     """
+
+     config_class = UniRNAConfig
+     supports_gradient_checkpointing = True
+     main_input_name = "input_ids"
+
+     def __init__(self, config):
+         # Explicitly block usage until this head is trained and validated.
+         raise RuntimeError(
+             "UniRNAForSSPredict is disabled and not supported. This head is untrained and must not be called."
+         )
+
+     def _set_gradient_checkpointing(self, enable: bool, gradient_checkpointing_func=None):
+         self.encoder.gradient_checkpointing = enable
+         if gradient_checkpointing_func is not None:
+             self.encoder._gradient_checkpointing_func = gradient_checkpointing_func
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, UniRNASSPredictionOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
+             Used to hide legacy arguments that have been deprecated.
+         """
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         if attention_mask is None:
+             attention_mask = torch.ones((batch_size, seq_length), device=device)
+
+         extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+         )
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             attention_mask=extended_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+
+         sequence_output = encoder_outputs[0]
+         logits, pair_mask = self.heads(sequence_output, attention_mask=attention_mask, return_mask=True)
+
+         loss = None
+         if labels is not None:
+             if labels.dim() == 3:
+                 labels = labels.unsqueeze(-1)
+             if labels.shape[1] == logits.shape[1] + 2 and labels.shape[2] == logits.shape[2] + 2:
+                 labels = labels[:, 1:-1, 1:-1, :]
+             labels = labels.to(logits.dtype)
+             loss_fct = nn.BCEWithLogitsLoss()
+             if pair_mask is not None:
+                 loss = loss_fct(logits[pair_mask], labels[pair_mask])
+             else:
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits, encoder_outputs.hidden_states, encoder_outputs.attentions, pair_mask)
+             return ((loss,) + output) if loss is not None else output
+
+         return UniRNASSPredictionOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+             pair_mask=pair_mask,
+         )
+
+
+ class UniRNALMHead(nn.Module):
+     """UniRNA Head for masked language modeling."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, features):
+         x = self.dense(features)
+         x = nn.functional.gelu(x)
+         x = self.layer_norm(x)
+
+         # project back to size of vocabulary with bias
+         x = self.decoder(x)
+         return x
+
+
+ class Dense(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         norm: str = "LayerNorm",
+         activation: str = "ReLU",
+         dropout: float = 0.1,
+         pool: str = "AdaptiveAvgPool1d",
+         bias: bool = True,
+         residual: bool = True,
+     ) -> None:
+         super().__init__()
+         self.residual = residual
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         self.norm = getattr(nn, norm)(out_features) if norm else nn.Identity()
+         self.activation = getattr(nn, activation)() if activation else nn.Identity()
+         self.dropout = nn.Dropout(dropout)
+         self.pool = getattr(nn, pool)(out_features) if pool else nn.Identity() if self.residual else None
+
+     def forward(self, x):
+         out = self.linear(x)
+         out = self.norm(out)
+         out = self.activation(out)
+         out = self.dropout(out)
+         if self.residual:
+             out = out + self.pool(x)
+         return out
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         *features: Sequence[int],
+         norm: str = "LayerNorm",
+         activation: str = "ReLU",
+         dropout: float = 0.1,
+         pool: str = "AdaptiveAvgPool1d",
+         bias: bool = True,
+         residual: bool = True,
+         linear_output: bool = True
+     ) -> None:
+         super().__init__()
+         if len(features) == 1 and isinstance(features[0], Sequence):
+             features = features[0]  # type: ignore[assignment]
+         if not len(features) > 1:
+             raise ValueError(f"`features` of MLP should have at least 2 elements, but got {len(features)}")
+         dense = partial(
+             Dense,
+             norm=norm,
+             activation=activation,
+             dropout=dropout,
+             pool=pool,
+             bias=bias,
+             residual=residual,
+         )
+         if linear_output:
+             layers = [dense(in_features, out_features) for in_features, out_features in zip(features, features[1:-1])]
+             layers.append(nn.Linear(features[-2], features[-1], bias=bias))
+         else:
+             layers = [dense(in_features, out_features) for in_features, out_features in zip(features, features[1:])]
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.layers(x)
+
+
+ class UniRNASSHead(nn.Module):
+     """UniRNA head for Secondary Structure Prediction"""
+
+     def __init__(self, config) -> None:
+         super().__init__()
+
+         self.qk_proj = nn.Linear(config.hidden_size, 2 * config.hidden_size)
+         self.ffn = MLP(1, config.hidden_size, residual=False)
+         self.linear = nn.Linear(config.hidden_size, 1)
+
+     def forward(self, features, attention_mask: Optional[torch.Tensor] = None, return_mask: bool = False):
+         x = features[:, 1:-1]  # remove CLS and EOS tokens
+         q, k = self.qk_proj(x).chunk(2, dim=-1)
+         contact_map = (q @ k.transpose(-2, -1)).unsqueeze(-1)
+         contact_map = contact_map + self.ffn(contact_map)
+         logits = self.linear(contact_map)
+
+         pair_mask = None
+         if attention_mask is not None:
+             core_mask = attention_mask[:, 1:-1].bool()
+             pair_mask = core_mask.unsqueeze(-1) & core_mask.unsqueeze(-2)
+             pair_mask = pair_mask.unsqueeze(-1)
+             logits = logits.masked_fill(~pair_mask, 0.0)
+
+         return (logits, pair_mask) if return_mask else logits
+
+
+ class AvgPooler(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, hidden_states, attention_mask=None):
+         if attention_mask is None:
+             attention_mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device, dtype=torch.bool)
+         else:
+             attention_mask = attention_mask.bool()
+
+         if hidden_states.size(1) > 2:
+             core_states = hidden_states[:, 1:-1, :]
+             core_mask = attention_mask[:, 1:-1]
+         else:
+             core_states = hidden_states
+             core_mask = attention_mask
+
+         core_mask = core_mask.unsqueeze(-1)
+         masked_states = core_states * core_mask
+         denom = core_mask.sum(dim=1).clamp(min=1).to(hidden_states.dtype)
+         return masked_states.sum(dim=1) / denom
+
+
+ class UniRNAModels(UniRNAModel):
+     config_class = UniRNAConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # We didn't include weights for the original pooler, so we replace it with a simple average pooler
+         del self.pooler
+         self.pooler = AvgPooler()
+
+
+ class UniRNAForMLM(UniRNAForMaskedLM):
+     config_class = UniRNAConfig
+     supports_gradient_checkpointing = True
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+         # We didn't include weights for the original pooler, so we replace it with a simple average pooler
+         self.pooler = AvgPooler()
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e1f93036483ef3f89e47136f1fce27a05d2178fc4a0659dd0b9e89dea3219e7
+ size 676213784
tokenizer.py ADDED
@@ -0,0 +1,162 @@
+ # Copyright Beijing DP Technology Co.,Ltd. All rights reserved.
+
+ """
+ The code is modified from the original ESM tokenizer provided by HuggingFace.
+ Sources: https://github.com/huggingface/transformers/blob/main/src/transformers/models/esm/tokenization_esm.py
+ """
+
+ import os
+ from typing import List, Optional, Union
+
+ from transformers.tokenization_utils import PreTrainedTokenizer
+ from transformers.tokenization_utils_base import AddedToken
+
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+ def load_vocab_file(vocab_file):
+     """Load vocabulary tokens from file into a list of strings."""
+     with open(vocab_file, "r") as f:
+         lines = f.read().splitlines()
+     return [line.strip() for line in lines]
+
+
+ class UniRNATokenizer(PreTrainedTokenizer):
+     """
+     Constructs a UniRNA tokenizer, based on the ESM tokenizer provided by HuggingFace.
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         vocab_file,
+         unk_token="N",
+         cls_token="<cls>",
+         pad_token="<pad>",
+         mask_token="<mask>",
+         eos_token="<eos>",
+         replace_uracil: bool = False,
+         **kwargs,
+     ):
+         self.all_tokens = load_vocab_file(vocab_file)
+         self._id_to_token = dict(enumerate(self.all_tokens))
+         self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
+         super().__init__(
+             unk_token=unk_token,
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             eos_token=eos_token,
+             **kwargs,
+         )
+
+         # Optional compatibility switch for DNA-only workflows.
+         if replace_uracil and "U" in self._token_to_id and "T" in self._token_to_id:
+             self._token_to_id["U"] = self._token_to_id["T"]
+         self.unique_no_split_tokens = self.all_tokens
+         self._update_trie(self.unique_no_split_tokens)
+
+     def _convert_token_to_id(self, token: str) -> int:
+         token = token.upper() if token not in self.all_special_tokens else token
+         unk_id = self._token_to_id.get(self.unk_token)
+         if unk_id is None:
+             unk_id = self.unk_token_id
+         return self._token_to_id.get(token, unk_id)
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._id_to_token.get(index, self.unk_token)
+
+     def token_to_id(self, token: str) -> int:
+         return self._convert_token_to_id(token)
+
+     def id_to_token(self, index: int) -> str:
+         return self._convert_id_to_token(index)
+
+     def _tokenize(self, text, **kwargs):
+         text = text.strip()
+         if not text:
+             return []
+         if any(ch.isspace() for ch in text):
+             return text.split()
+         return list(text)
+
+     def get_vocab_size(self, with_added_tokens=False):
+         if with_added_tokens:
+             return len(self.get_vocab())
+         return len(self._id_to_token)
+
+     def get_vocab(self):
+         vocab = {token: i for i, token in enumerate(self.all_tokens)}
+         vocab.update(self.added_tokens_encoder)
+         return vocab
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         cls = [self.cls_token_id]
+         sep = [self.eos_token_id]
+         if token_ids_1 is None:
+             if self.eos_token_id is None:
+                 return cls + token_ids_0
+             else:
+                 return cls + token_ids_0 + sep
+         elif self.eos_token_id is None:
+             raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
+         return cls + token_ids_0 + sep + token_ids_1 + sep  # Multiple inputs always have an EOS token
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List,
+         token_ids_1: Optional[List] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         """
+         Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of ids of the first sequence.
+             token_ids_1 (`List[int]`, *optional*):
+                 List of ids of the second sequence.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+
+         Returns:
+             A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+         if already_has_special_tokens:
+             if token_ids_1 is not None:
+                 raise ValueError(
+                     "You should not supply a second sequence if the provided sequence of "
+                     "ids is already formatted with special tokens for the model."
+                 )
+
+             return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
+         mask = [1] + ([0] * len(token_ids_0)) + [1]
+         if token_ids_1 is not None:
+             mask += [0] * len(token_ids_1) + [1]
+         return mask
+
+     def save_vocabulary(self, save_directory, filename_prefix=None):
+         os.makedirs(save_directory, exist_ok=True)
+         vocab_file = os.path.join(
+             save_directory,
+             (filename_prefix + "-" if filename_prefix else "") + "vocab.txt",
+         )
+         with open(vocab_file, "w") as f:
+             f.write("\n".join(self.all_tokens))
+         return (vocab_file,)
+
+     @property
+     def vocab_size(self) -> int:
+         return self.get_vocab_size(with_added_tokens=False)
+
+     def _add_tokens(
+         self,
+         new_tokens: Union[List[str], List[AddedToken]],
+         special_tokens: bool = False,
+     ) -> int:
+         return super()._add_tokens(new_tokens, special_tokens=special_tokens)
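
The tokenizer is character-level over the ten entries of vocab.txt (ids follow that file's order), uppercases letters during id lookup, treats whitespace-separated input as pre-split tokens, and wraps every sequence in <cls> ... <eos>. A quick sketch (placeholder path):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this/checkpoint", trust_remote_code=True)  # placeholder

enc = tokenizer("ACGU")
print(enc["input_ids"])                                   # [3, 5, 7, 8, 9, 1]
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))  # ['<cls>', 'A', 'C', 'G', 'U', '<eos>']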
tokenizer_config.json ADDED
@@ -0,0 +1,60 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<eos>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "N",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<cls>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "<mask>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "backend": "custom",
+   "cls_token": "<cls>",
+   "eos_token": "<eos>",
+   "is_local": true,
+   "mask_token": "<mask>",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "<pad>",
+   "sep_token": "<eos>",
+   "tokenizer_class": "UniRNATokenizer",
+   "unk_token": "N",
+   "auto_map": {
+     "AutoTokenizer": [
+       "tokenizer.UniRNATokenizer",
+       null
+     ]
+   }
+ }
vocab.txt ADDED
@@ -0,0 +1,10 @@
+ <pad>
+ <eos>
+ N
+ <cls>
+ <mask>
+ A
+ T
+ C
+ G
+ U