GinnM committed on
Commit
a0583a7
·
verified ·
1 Parent(s): c8e3b68

Upload ProSSTForMaskedLM

Files changed (5)
  1. README.md +199 -0
  2. config.json +45 -0
  3. configuration_prosst.py +71 -0
  4. model.safetensors +3 -0
  5. modeling_prosst.py +1198 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "_name_or_path": "oracle_checkpoint3/ss_2051_0_aa2pos_pos2aa_aa2ss_ss2aa_False/ProSSTX-2048-2",
+   "architectures": [
+     "ProSSTForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoConfig": "configuration_prosst.ProSSTConfig",
+     "AutoModelForMaskedLM": "modeling_prosst.ProSSTForMaskedLM"
+   },
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "mask_token_id": 24,
+   "max_position_embeddings": -1,
+   "max_relative_positions": 1024,
+   "mlm_probability": 0.15,
+   "model_type": "ProSST",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "pad_token_id": 0,
+   "pooler_dropout": 0.1,
+   "pooler_hidden_act": "gelu",
+   "pooler_hidden_size": 1024,
+   "pooling_head": "mean",
+   "pos_att_type": [
+     "aa2pos",
+     "pos2aa",
+     "aa2ss",
+     "ss2aa"
+   ],
+   "position_biased_input": false,
+   "position_embedding_type": "relative",
+   "relative_attention": true,
+   "scale_hidden": 1,
+   "ss_vocab_size": 2051,
+   "token_dropout": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.2",
+   "type_vocab_size": 0,
+   "vocab_size": 25
+ }
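
Reviewer sketch (not part of the commit): the attention code in `modeling_prosst.py` derives the per-head width as `hidden_size / num_attention_heads`. The snippet below checks that arithmetic on a hand-copied subset of the values in `config.json`; the variable names `config_text` and `head_size` are illustrative assumptions.

```python
import json

# Subset of the values from config.json in this commit, copied by hand.
config_text = """
{
  "hidden_size": 1024,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "intermediate_size": 4096
}
"""

config = json.loads(config_text)

# The hidden size has to split evenly across heads for the
# reshape into [batch, heads, seq_len, head_size] to work.
assert config["hidden_size"] % config["num_attention_heads"] == 0

head_size = config["hidden_size"] // config["num_attention_heads"]
print(head_size)
```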
configuration_prosst.py ADDED
@@ -0,0 +1,71 @@
+ from transformers import PretrainedConfig
+
+ class ProSSTConfig(PretrainedConfig):
+     model_type = "ProSST"
+
+     def __init__(
+         self,
+         token_dropout=True,
+         mlm_probability=0.15,
+         vocab_size=1024,
+         type_vocab_size=0,
+         ss_vocab_size=0,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         mask_token_id=24,
+         initializer_range=0.02,
+         layer_norm_eps=1e-7,
+         pad_token_id=0,
+         position_biased_input=False,
+         pooler_dropout=0,
+         pooler_hidden_act="gelu",
+         pos_att_type=None,
+         position_embedding_type="relative",
+         max_position_embeddings=1024,
+         max_relative_positions=1024,
+         relative_attention=False,
+         pooling_head="mean",
+         scale_hidden=1,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.token_dropout = token_dropout
+         self.mlm_probability = mlm_probability
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.ss_vocab_size = ss_vocab_size
+         self.initializer_range = initializer_range
+         self.relative_attention = relative_attention
+         self.max_relative_positions = max_relative_positions
+         self.pad_token_id = pad_token_id
+         self.position_biased_input = position_biased_input
+         self.mask_token_id = mask_token_id
+         self.position_embedding_type = position_embedding_type
+         self.pooling_head = pooling_head
+         self.scale_hidden = scale_hidden
+
+         # Backwards compatibility: accept "a|b|c" strings as well as lists
+         if isinstance(pos_att_type, str):
+             pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")]
+
+         self.pos_att_type = pos_att_type
+         self.vocab_size = vocab_size
+         self.layer_norm_eps = layer_norm_eps
+
+         self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
+         self.pooler_dropout = pooler_dropout
+         self.pooler_hidden_act = pooler_hidden_act
+
+ ProSSTConfig.register_for_auto_class()
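
Reviewer sketch (not part of the commit): the backwards-compatibility branch in `ProSSTConfig.__init__` turns a `|`-separated string into a list of attention-type names. The standalone helper below mirrors that branch so its behavior can be checked without installing `transformers`; the name `normalize_pos_att_type` is a hypothetical one chosen for illustration.

```python
from typing import List, Optional, Union


def normalize_pos_att_type(
    value: Union[str, List[str], None]
) -> Optional[List[str]]:
    """Mirror of the string-handling branch in ProSSTConfig.__init__."""
    if isinstance(value, str):
        # Lowercase first, then split on "|" and trim surrounding spaces.
        return [x.strip() for x in value.lower().split("|")]
    # Lists and None pass through unchanged.
    return value


legacy = normalize_pos_att_type("AA2POS | pos2aa|aa2ss")
modern = normalize_pos_att_type(["aa2pos", "pos2aa", "aa2ss", "ss2aa"])
print(legacy, modern)
```

Note that a list is passed through without lowercasing, so only the string form is normalized.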
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5aced67de1bcff1ea82d8159a1ede7e102206dec8c865d2588462842066062da
+ size 1633510992
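
Reviewer sketch (not part of the commit): the three lines above are a Git LFS pointer, not the weights themselves. Since `config.json` declares `torch_dtype` as `float32` (4 bytes per value), the `size` field gives a rough upper bound on the parameter count. The helper name `parse_lfs_pointer` is hypothetical, and the estimate ignores the safetensors header, so the true count is slightly lower.

```python
pointer_text = """version https://git-lfs.github.com/spec/v1
oid sha256:5aced67de1bcff1ea82d8159a1ede7e102206dec8c865d2588462842066062da
size 1633510992
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line of an LFS pointer into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields


pointer = parse_lfs_pointer(pointer_text)
size_bytes = int(pointer["size"])

# float32 stores each parameter in 4 bytes; header bytes are not subtracted.
approx_params = size_bytes // 4
print(approx_params)
```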
modeling_prosst.py ADDED
@@ -0,0 +1,1198 @@
+ from collections.abc import Sequence
+ from typing import Optional, Tuple, Union
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     MaskedLMOutput,
+     SequenceClassifierOutput,
+     TokenClassifierOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from .configuration_prosst import ProSSTConfig
+ import torch.nn.functional as F
+ from functools import partial
+
+
+ def rbf(values, v_min, v_max, n_bins=16):
+     """
+     Returns RBF encodings in a new dimension at the end.
+     https://github.com/evolutionaryscale/esm/blob/main/esm/utils/misc.py
+     """
+     rbf_centers = torch.linspace(
+         v_min, v_max, n_bins, device=values.device, dtype=values.dtype
+     )
+     rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
+     rbf_std = (v_max - v_min) / n_bins
+     z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
+     return torch.exp(-(z**2))
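
Reviewer sketch (not part of the commit): `rbf` places `n_bins` Gaussian centers evenly between `v_min` and `v_max` and uses `(v_max - v_min) / n_bins` as the width. The dependency-free version below applies the same formula to a single scalar; `rbf_scalar` is a hypothetical name, and it assumes `n_bins >= 2`.

```python
import math


def rbf_scalar(value, v_min, v_max, n_bins=16):
    """Scalar version of the RBF encoding: one Gaussian response per center."""
    # torch.linspace(v_min, v_max, n_bins) spaces centers by this step.
    step = (v_max - v_min) / (n_bins - 1)
    centers = [v_min + i * step for i in range(n_bins)]
    std = (v_max - v_min) / n_bins
    return [math.exp(-(((value - c) / std) ** 2)) for c in centers]


# A value sitting exactly on the first center responds with 1.0 there
# and decays monotonically for centers further away.
enc = rbf_scalar(0.0, 0.0, 20.0, n_bins=16)
print(len(enc), enc[0])
```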
+
+
+ def build_relative_position(query_size, key_size, device):
+     """
+     Build relative positions for the given query and key lengths.
+
+     We assume the absolute positions of the query \\(P_q\\) range over (0, query_size) and the absolute positions of
+     the key \\(P_k\\) range over (0, key_size). The relative position from query to key is
+     \\(R_{q \\rightarrow k} = P_q - P_k\\)
+
+     Args:
+         query_size (int): the length of query
+         key_size (int): the length of key
+         device (torch.device): the device on which the position ids are created
+
+     Return:
+         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+
+     """
+
+     q_ids = torch.arange(query_size, dtype=torch.long, device=device)
+     k_ids = torch.arange(key_size, dtype=torch.long, device=device)
+     rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
+     rel_pos_ids = rel_pos_ids[:query_size, :]
+     rel_pos_ids = rel_pos_ids.unsqueeze(0)
+     return rel_pos_ids
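
Reviewer sketch (not part of the commit): the relative-position matrix is simply `q - k` for every query/key pair, and the attention code later shifts it by `att_span` and clamps it to `[0, 2 * att_span - 1]` to index the relative embedding table. The plain-Python helpers below (`relative_positions` and `to_embedding_index` are hypothetical names) reproduce both steps on a tiny example.

```python
def relative_positions(query_size, key_size):
    """Entry [q][k] holds q - k, matching build_relative_position without the batch dim."""
    return [[q - k for k in range(key_size)] for q in range(query_size)]


def to_embedding_index(rel, att_span):
    """Shift by att_span and clamp to [0, 2 * att_span - 1], as in the aa2pos branch."""
    upper = 2 * att_span - 1
    return [[min(max(r + att_span, 0), upper) for r in row] for row in rel]


rel = relative_positions(3, 4)
idx = to_embedding_index(rel, att_span=4)
print(rel)
print(idx)
```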
+
+
+ @torch.jit.script
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+     return c2p_pos.expand(
+         [
+             query_layer.size(0),
+             query_layer.size(1),
+             query_layer.size(2),
+             relative_pos.size(-1),
+         ]
+     )
+
+
+ @torch.jit.script
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+     return c2p_pos.expand(
+         [
+             query_layer.size(0),
+             query_layer.size(1),
+             key_layer.size(-2),
+             key_layer.size(-2),
+         ]
+     )
+
+
+ @torch.jit.script
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+     return pos_index.expand(
+         p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
+     )
+
+
+ def rotate_half(x):
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(x, cos, sin):
+     cos = cos[:, :, : x.shape[-2], :]
+     sin = sin[:, :, : x.shape[-2], :]
+
+     return (x * cos) + (rotate_half(x) * sin)
+
+
+ class RotaryEmbedding(torch.nn.Module):
+     """
+     Rotary position embeddings based on those in
+     [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
+     matrices which depend on their relative positions.
+     """
+
+     def __init__(self, dim: int):
+         super().__init__()
+         # Generate and save the inverse frequency buffer (non trainable)
+         inv_freq = 1.0 / (
+             10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
+         )
+         self.register_buffer("inv_freq", inv_freq)
+
+         self._seq_len_cached = None
+         self._cos_cached = None
+         self._sin_cached = None
+
+     def _update_cos_sin_tables(self, x, seq_dimension=2):
+         seq_len = x.shape[seq_dimension]
+
+         # Reset the tables if the sequence length has changed,
+         # or if we're on a new device (possibly due to tracing for instance)
+         if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
+             self._seq_len_cached = seq_len
+             t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(
+                 self.inv_freq
+             )
+             freqs = torch.outer(t, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+
+             self._cos_cached = emb.cos()[None, None, :, :]
+             self._sin_cached = emb.sin()[None, None, :, :]
+
+         return self._cos_cached, self._sin_cached
+
+     def forward(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
+             k, seq_dimension=-2
+         )
+
+         return (
+             apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
+             apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
+         )
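
Reviewer sketch (not part of the commit): the docstring says the rotation depends on relative positions. A dependency-free way to see this is to rotate a query and a key at two different absolute position pairs with the same offset and compare the dot products. The function `rotary` below is a hypothetical single-vector rewrite of `rotate_half` plus `apply_rotary_pos_emb`, assuming an even head dimension and the same base of 10000.

```python
import math


def rotate_half(x):
    """Split the vector in two halves and return (-second, first)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def rotary(x, position, base=10000.0):
    """Apply the rotary embedding to one vector at one absolute position."""
    dim = len(x)
    inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
    # Same layout as torch.cat((freqs, freqs), dim=-1).
    angles = [position * f for f in inv_freq] * 2
    rot = rotate_half(x)
    return [
        x[i] * math.cos(angles[i]) + rot[i] * math.sin(angles[i])
        for i in range(dim)
    ]


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.2]

# Both pairs are 3 positions apart, so the scores should agree.
score_a = dot(rotary(q, 5), rotary(k, 2))
score_b = dot(rotary(q, 13), rotary(k, 10))
print(score_a, score_b)
```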
+
+
+ class MaskedConv1d(nn.Conv1d):
+     """A masked 1-dimensional convolution layer.
+
+     Takes the same arguments as torch.nn.Conv1d, except that the padding is set automatically.
+
+     Shape:
+         Input: (N, L, in_channels)
+         input_mask: (N, L, 1), optional
+         Output: (N, L, out_channels)
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+     ):
+         """
+         :param in_channels: input channels
+         :param out_channels: output channels
+         :param kernel_size: the kernel width
+         :param stride: filter shift
+         :param dilation: dilation factor
+         :param groups: perform depth-wise convolutions
+         :param bias: adds learnable bias to output
+         """
+         padding = dilation * (kernel_size - 1) // 2
+         super().__init__(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+             padding=padding,
+         )
+
+     def forward(self, x, input_mask=None):
+         if input_mask is not None:
+             x = x * input_mask
+         return super().forward(x.transpose(1, 2)).transpose(1, 2)
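
Reviewer sketch (not part of the commit): the automatic padding `dilation * (kernel_size - 1) // 2` keeps the sequence length unchanged for odd kernel sizes at stride 1, which is what the `(N, L, ...)` output shape in the docstring relies on. The helpers below plug that padding into the standard 1D convolution output-length formula; `same_padding` and `conv1d_output_length` are hypothetical names.

```python
def same_padding(kernel_size, dilation=1):
    """Padding rule used in MaskedConv1d.__init__."""
    return dilation * (kernel_size - 1) // 2


def conv1d_output_length(length, kernel_size, stride=1, dilation=1):
    """Standard Conv1d output length with the padding rule above."""
    padding = same_padding(kernel_size, dilation)
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


print(conv1d_output_length(100, kernel_size=1))              # pointwise conv
print(conv1d_output_length(100, kernel_size=9, dilation=2))  # odd kernel
print(conv1d_output_length(100, kernel_size=4))              # even kernel
```

With an even kernel the integer division drops half a position of padding, so the output comes out one element shorter than the input.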
+
+
+ class Attention1dPooling(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.layer = MaskedConv1d(config.hidden_size, 1, 1)
+
+     def forward(self, x, input_mask=None):
+         batch_size = x.shape[0]
+         attn = self.layer(x)
+         attn = attn.view(batch_size, -1)
+         if input_mask is not None:
+             attn = attn.masked_fill_(
+                 ~input_mask.view(batch_size, -1).bool(), float("-inf")
+             )
+         attn = F.softmax(attn, dim=-1).view(batch_size, -1, 1)
+         out = (attn * x).sum(dim=1)
+         return out
+
+
+ class MeanPooling(nn.Module):
+     """Mean Pooling for sentence-level classification tasks."""
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, features, input_mask=None):
+         if input_mask is not None:
+             # Applying input_mask to zero out masked values
+             masked_features = features * input_mask.unsqueeze(2)
+             sum_features = torch.sum(masked_features, dim=1)
+             mean_pooled_features = sum_features / input_mask.sum(dim=1, keepdim=True)
+         else:
+             mean_pooled_features = torch.mean(features, dim=1)
+         return mean_pooled_features
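
Reviewer sketch (not part of the commit): `MeanPooling` zeroes out masked positions and divides by the number of unmasked ones, so padding does not dilute the average. The list-based version below shows the same computation for a single sequence; `masked_mean` is a hypothetical name and the mask is assumed to contain at least one 1.

```python
def masked_mean(features, mask):
    """Average the rows of `features` whose mask entry is 1."""
    hidden = len(features[0])
    count = sum(mask)
    return [
        sum(m * row[j] for row, m in zip(features, mask)) / count
        for j in range(hidden)
    ]


features = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is padding
mask = [1, 1, 0]

pooled = masked_mean(features, mask)
print(pooled)
```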
+
+
+ class ContextPooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         scale_hidden = getattr(config, "scale_hidden", 1)
+         if config.pooling_head == "mean":
+             self.mean_pooling = MeanPooling()
+         elif config.pooling_head == "attention":
+             self.mean_pooling = Attention1dPooling(config)
+         self.dense = nn.Linear(
+             config.pooler_hidden_size, scale_hidden * config.pooler_hidden_size
+         )
+         self.dropout = nn.Dropout(config.pooler_dropout)
+         self.config = config
+
+     def forward(self, hidden_states, input_mask=None):
+         # Pool the sequence with the configured head (masked mean or 1D
+         # attention), then apply dropout, a dense projection and tanh.
+
+         context_token = self.mean_pooling(hidden_states, input_mask)
+         context_token = self.dropout(context_token)
+         pooled_output = self.dense(context_token)
+         pooled_output = torch.tanh(pooled_output)
+         return pooled_output
+
+     @property
+     def output_dim(self):
+         return self.config.hidden_size
+
+
+ class ProSSTLayerNorm(nn.Module):
+     """LayerNorm module in the TF style (epsilon inside the square root)."""
+
+     def __init__(self, size, eps=1e-12):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(size))
+         self.bias = nn.Parameter(torch.zeros(size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_type = hidden_states.dtype
+         hidden_states = hidden_states.float()
+         mean = hidden_states.mean(-1, keepdim=True)
+         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+         hidden_states = (hidden_states - mean) / torch.sqrt(
+             variance + self.variance_epsilon
+         )
+         hidden_states = hidden_states.to(input_type)
+         y = self.weight * hidden_states + self.bias
+         return y
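
Reviewer sketch (not part of the commit): "epsilon inside the square root" means the normalizer is `sqrt(variance + eps)`, using the biased (population) variance over the last dimension. The scalar-list version below follows the same steps as `ProSSTLayerNorm.forward`; `layer_norm` is a hypothetical name and the default `eps` is taken from `layer_norm_eps` in `config.json`.

```python
import math


def layer_norm(x, weight, bias, eps=1e-5):
    """Normalize one vector, then apply the elementwise scale and shift."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    denom = math.sqrt(variance + eps)                    # eps inside the sqrt
    return [w * (v - mean) / denom + b for v, w, b in zip(x, weight, bias)]


y = layer_norm([1.0, 2.0, 3.0, 4.0], weight=[1.0] * 4, bias=[0.0] * 4)
y_mean = sum(y) / len(y)
y_var = sum((v - y_mean) ** 2 for v in y) / len(y)
print(y_mean, y_var)
```

The output has zero mean and a variance just under 1, the small gap coming from `eps`.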
+
+
+ class DisentangledSelfAttention(nn.Module):
+
+     def __init__(self, config: ProSSTConfig):
+         super().__init__()
+         self.config = config
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+         # Q, K, V projection layers
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+         # AA->SS, AA->POS, SS->AA, POS->AA and AA->AA attention layers
+         if config.pos_att_type is not None:
+             self.pos_att_type = config.pos_att_type
+         else:
+             self.pos_att_type = []
+         self.relative_attention = config.relative_attention
+         self.position_embedding_type = config.position_embedding_type
+
+         if self.position_embedding_type == "rotary":
+             self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
+             if self.relative_attention:
+                 if "aa2ss" in self.pos_att_type:
+                     self.ss_proj = nn.Linear(
+                         config.hidden_size, self.all_head_size, bias=False
+                     )
+                 if "ss2aa" in self.pos_att_type:
+                     self.ss_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+         elif self.position_embedding_type == "relative":
+             if self.relative_attention:
+                 self.max_relative_positions = config.max_relative_positions
+                 self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
+                 # AA2POS
+                 if "aa2pos" in self.pos_att_type:
+                     self.pos_proj = nn.Linear(
+                         config.hidden_size, self.all_head_size, bias=False
+                     )  # Key
+                 # POS2AA
+                 if "pos2aa" in self.pos_att_type:
+                     self.pos_q_proj = nn.Linear(
+                         config.hidden_size, self.all_head_size
+                     )  # Query
+                 # AA2SS
+                 if "aa2ss" in self.pos_att_type:
+                     self.ss_proj = nn.Linear(
+                         config.hidden_size, self.all_head_size, bias=False
+                     )
+                 # SS2AA
+                 if "ss2aa" in self.pos_att_type:
+                     self.ss_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def transpose_for_scores(self, x):
+         # x [batch_size, seq_len, all_head_size]
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
+         # x [batch_size, seq_len, num_attention_heads, attention_head_size]
+         x = x.view(new_x_shape)
+         # x [batch_size, num_attention_heads, seq_len, attention_head_size]
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask,
+         ss_hidden_states=None,
+         relative_pos=None,
+         rel_embeddings=None,
+         output_attentions=False,
+     ):
+         query_layer = self.transpose_for_scores(self.query(hidden_states))
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+         if self.position_embedding_type == "rotary":
+             query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
+         rel_att = None
+         scale_factor = 1 + len(self.pos_att_type)
+         scale = torch.sqrt(
+             torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor
+         )
+         query_layer = query_layer / scale.to(dtype=query_layer.dtype)
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         if self.relative_attention:
+             if self.position_embedding_type == "relative":
+                 rel_embeddings = self.pos_dropout(rel_embeddings)
+             rel_att, output_attentions_dict = self.disentangled_att_bias(
+                 query_layer,
+                 key_layer,
+                 relative_pos,
+                 rel_embeddings,
+                 scale_factor,
+                 ss_hidden_states,
+             )
+             output_attentions_dict["aa2aa"] = attention_scores
+             attention_scores = attention_scores + rel_att
+
+         rmask = ~(attention_mask.to(torch.bool))
+         attention_probs = attention_scores.masked_fill(
+             rmask, torch.finfo(attention_scores.dtype).min
+         )
+         attention_probs = torch.softmax(attention_probs, -1)
+         # attention_probs = attention_probs.masked_fill(rmask, 0.0)
+         attention_probs = self.dropout(attention_probs)
+         context_layer = torch.matmul(attention_probs, value_layer)
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+         context_layer = context_layer.view(new_context_layer_shape)
+         if output_attentions:
+             if self.relative_attention:
+                 return (context_layer, output_attentions_dict)
+             else:
+                 return (context_layer, attention_probs)
+         else:
+             return context_layer
+
+     def disentangled_att_bias(
+         self,
+         query_layer,
+         key_layer,
+         relative_pos,
+         rel_embeddings,
+         scale_factor,
+         ss_hidden_states,
+     ):
+         disentangled_attentions = {}
+         if self.position_embedding_type == "relative":
+             if relative_pos.dim() == 2:
+                 relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+             elif relative_pos.dim() == 3:
+                 relative_pos = relative_pos.unsqueeze(1)
+             # bxhxqxk
+             elif relative_pos.dim() != 4:
+                 raise ValueError(
+                     f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}"
+                 )
+
+             att_span = min(
+                 max(query_layer.size(-2), key_layer.size(-2)),
+                 self.max_relative_positions,
+             )
+             relative_pos = relative_pos.long().to(query_layer.device)
+             rel_embeddings = rel_embeddings[
+                 self.max_relative_positions
+                 - att_span : self.max_relative_positions
+                 + att_span,
+                 :,
+             ].unsqueeze(0)
+             score = 0
+             if "aa2pos" in self.pos_att_type:
+                 pos_key_layer = self.pos_proj(rel_embeddings)
+                 pos_key_layer = self.transpose_for_scores(pos_key_layer)
+                 aa2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
+                 aa2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+                 aa2p_att = torch.gather(
+                     aa2p_att,
+                     dim=-1,
+                     index=c2p_dynamic_expand(aa2p_pos, query_layer, relative_pos),
+                 )
+                 score += aa2p_att
+                 disentangled_attentions["aa2pos"] = aa2p_att
+             if "pos2aa" in self.pos_att_type:
+                 pos_query_layer = self.pos_q_proj(rel_embeddings)
+                 pos_query_layer = self.transpose_for_scores(pos_query_layer)
+                 pos_query_layer /= torch.sqrt(
+                     torch.tensor(pos_query_layer.size(-1), dtype=torch.float)
+                     * scale_factor
+                 )
+                 if query_layer.size(-2) != key_layer.size(-2):
+                     r_pos = build_relative_position(
+                         key_layer.size(-2), key_layer.size(-2), query_layer.device
+                     )
+                 else:
+                     r_pos = relative_pos
+                 p2aa_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+                 p2aa_att = torch.matmul(
+                     key_layer,
+                     pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype),
+                 )
+                 p2aa_att = torch.gather(
+                     p2aa_att,
+                     dim=-1,
+                     index=p2c_dynamic_expand(p2aa_pos, query_layer, key_layer),
+                 ).transpose(-1, -2)
+
+                 if query_layer.size(-2) != key_layer.size(-2):
+                     pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
+                     p2aa_att = torch.gather(
+                         p2aa_att,
+                         dim=-2,
+                         index=pos_dynamic_expand(pos_index, p2aa_att, key_layer),
+                     )
+                 score += p2aa_att
+                 disentangled_attentions["pos2aa"] = p2aa_att
+             # content -> structure
+             if "aa2ss" in self.pos_att_type:
+                 assert ss_hidden_states is not None
+                 ss_key_layer = self.ss_proj(ss_hidden_states)
+                 ss_key_layer = self.transpose_for_scores(ss_key_layer)
+                 # [batch_size, num_attention_heads, seq_len, seq_len]
+                 aa2ss_att = torch.matmul(query_layer, ss_key_layer.transpose(-1, -2))
+                 score += aa2ss_att
+                 disentangled_attentions["aa2ss"] = aa2ss_att
+             if "ss2aa" in self.pos_att_type:
+                 assert ss_hidden_states is not None
+                 ss_query_layer = self.ss_q_proj(ss_hidden_states)
+                 ss_query_layer = self.transpose_for_scores(ss_query_layer)
+                 ss_query_layer /= torch.sqrt(
+                     torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
+                     * scale_factor
+                 )
+                 # NOTE: ss_query_layer is computed above, but this product uses query_layer.
+                 ss2aa_att = torch.matmul(
+                     key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
+                 )
+                 score += ss2aa_att
+                 disentangled_attentions["ss2aa"] = ss2aa_att
+             return score, disentangled_attentions
+         elif self.position_embedding_type == "rotary":
+             score = 0
+             if "aa2ss" in self.pos_att_type:
+                 assert ss_hidden_states is not None
+                 ss_key_layer = self.ss_proj(ss_hidden_states)
+                 ss_key_layer = self.transpose_for_scores(ss_key_layer)
+                 aa2ss_att = torch.matmul(query_layer, ss_key_layer.transpose(-1, -2))
+                 score += aa2ss_att
+                 disentangled_attentions["aa2ss"] = aa2ss_att
+             if "ss2aa" in self.pos_att_type:
+                 assert ss_hidden_states is not None
+                 ss_query_layer = self.ss_q_proj(ss_hidden_states)
+                 ss_query_layer = self.transpose_for_scores(ss_query_layer)
+                 ss_query_layer /= torch.sqrt(
+                     torch.tensor(ss_query_layer.size(-1), dtype=torch.float)
+                     * scale_factor
+                 )
+                 # NOTE: ss_query_layer is computed above, but this product uses query_layer.
+                 ss2aa_att = torch.matmul(
+                     key_layer, query_layer.transpose(-1, -2).to(dtype=key_layer.dtype)
+                 )
+                 score += ss2aa_att
+                 disentangled_attentions["ss2aa"] = ss2aa_att
+             return score, disentangled_attentions
527
+
528
+
529
+ class ProSSTSelfOutput(nn.Module):
530
+ def __init__(self, config):
531
+ super().__init__()
532
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
533
+ self.LayerNorm = ProSSTLayerNorm(config.hidden_size, config.layer_norm_eps)
534
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
535
+
536
+ def forward(self, hidden_states, input_tensor):
537
+ hidden_states = self.dense(hidden_states)
538
+ hidden_states = self.dropout(hidden_states)
539
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
540
+ return hidden_states
541
+
542
+
543
+ class ProSSTAttention(nn.Module):
544
+ def __init__(self, config):
545
+ super().__init__()
546
+ self.config = config
547
+ self.self = DisentangledSelfAttention(config)
548
+ self.output = ProSSTSelfOutput(config)
549
+
550
+ def forward(
551
+ self,
552
+ hidden_states,
553
+ attention_mask,
554
+ ss_hidden_states=None,
555
+ relative_pos=None,
556
+ rel_embeddings=None,
557
+ output_attentions=False,
558
+ ):
559
+ self_output = self.self(
560
+ hidden_states=hidden_states,
561
+ attention_mask=attention_mask,
562
+ output_attentions=output_attentions,
563
+ relative_pos=relative_pos,
564
+ rel_embeddings=rel_embeddings,
565
+ ss_hidden_states=ss_hidden_states
566
+ )
567
+ if output_attentions:
568
+ self_output, att_matrix = self_output
569
+ attention_output = self.output(self_output, hidden_states)
570
+ if output_attentions:
571
+ return (attention_output, att_matrix)
572
+ else:
573
+ return attention_output
574
+
575
+
576
+ class ProSSTIntermediate(nn.Module):
577
+ def __init__(self, config):
578
+ super().__init__()
579
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
580
+ if isinstance(config.hidden_act, str):
581
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
582
+ else:
583
+ self.intermediate_act_fn = config.hidden_act
584
+
585
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
586
+ hidden_states = self.dense(hidden_states)
587
+ hidden_states = self.intermediate_act_fn(hidden_states)
588
+ return hidden_states
589
+
590
+
591
+ class ProSSTOutput(nn.Module):
592
+ def __init__(self, config):
593
+ super().__init__()
594
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
595
+ self.LayerNorm = ProSSTLayerNorm(config.hidden_size, config.layer_norm_eps)
596
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
597
+ self.config = config
598
+
599
+ def forward(self, hidden_states, input_tensor):
600
+ hidden_states = self.dense(hidden_states)
601
+ hidden_states = self.dropout(hidden_states)
602
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
603
+ return hidden_states
604
+
605
+
606
+ class ProSSTLayer(nn.Module):
607
+ def __init__(self, config):
608
+ super().__init__()
609
+ self.config = config
610
+ self.attention = ProSSTAttention(config)
611
+ self.intermediate = ProSSTIntermediate(config)
612
+ self.output = ProSSTOutput(config)
613
+
614
+ def forward(
615
+ self,
616
+ hidden_states,
617
+ attention_mask,
618
+ ss_hidden_states=None,
619
+ relative_pos=None,
620
+ rel_embeddings=None,
621
+ output_attentions=False,
622
+ ):
623
+ attention_output = self.attention(
624
+ hidden_states,
625
+ attention_mask,
626
+ output_attentions=output_attentions,
627
+ relative_pos=relative_pos,
628
+ rel_embeddings=rel_embeddings,
629
+ ss_hidden_states=ss_hidden_states,
630
+ )
631
+ if output_attentions:
632
+ attention_output, att_matrix = attention_output
633
+ intermediate_output = self.intermediate(attention_output)
634
+ layer_output = self.output(intermediate_output, attention_output)
635
+ if output_attentions:
636
+ return (layer_output, att_matrix)
637
+ else:
638
+ return layer_output
639
+
640
+
641
+ class ProSSTEncoder(nn.Module):
642
+ """Modified BertEncoder with relative position bias support"""
643
+
644
+ def __init__(self, config):
645
+ super().__init__()
646
+ self.layer = nn.ModuleList(
647
+ [ProSSTLayer(config) for _ in range(config.num_hidden_layers)]
648
+ )
649
+ self.relative_attention = config.relative_attention
650
+ if self.relative_attention:
651
+ self.max_relative_positions = config.max_relative_positions
652
+ self.rel_embeddings = nn.Embedding(
653
+ self.max_relative_positions * 2, config.hidden_size
654
+ )
655
+
656
+ def get_attention_mask(self, attention_mask):
657
+ if attention_mask.dim() <= 2:
658
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
659
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(
660
+ -2
661
+ ).unsqueeze(-1)
662
+ elif attention_mask.dim() == 3:
663
+ attention_mask = attention_mask.unsqueeze(1)
664
+
665
+ return attention_mask
666
+
667
+ def get_rel_pos(self, hidden_states):
668
+ query_size = hidden_states.size(-2)
669
+ key_size = hidden_states.size(-2)
670
+ relative_pos = build_relative_position(
671
+ query_size, key_size, hidden_states.device
672
+ )
673
+ return relative_pos
674
+
675
+ def forward(
676
+ self,
677
+ hidden_states,
678
+ attention_mask,
679
+ ss_hidden_states=None,
680
+ output_hidden_states=False,
681
+ output_attentions=False,
682
+ ) -> BaseModelOutput:
683
+ attention_mask = self.get_attention_mask(attention_mask)
684
+ relative_pos = self.get_rel_pos(hidden_states)
685
+ all_hidden_states = []
686
+ all_attentions = []
687
+ rel_embeddings = self.rel_embeddings.weight
688
+ for i, layer_module in enumerate(self.layer):
689
+ if output_hidden_states:
690
+ all_hidden_states.append(hidden_states)
691
+ hidden_states = layer_module(
692
+ hidden_states,
693
+ attention_mask,
694
+ relative_pos=relative_pos,
695
+ rel_embeddings=rel_embeddings,
696
+ output_attentions=output_attentions,
697
+ ss_hidden_states=ss_hidden_states,
698
+ )
699
+ if output_attentions:
700
+ hidden_states, att_matrix = hidden_states
701
+ all_attentions.append(att_matrix)
702
+ if output_hidden_states:
703
+ all_hidden_states.append(hidden_states)
704
+
705
+ return BaseModelOutput(
706
+ last_hidden_state=hidden_states,
707
+ hidden_states=all_hidden_states,
708
+ attentions=all_attentions,
709
+ )
710
+
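Reviewer note on `get_attention_mask` above: for a 2-D padding mask of shape `[batch, seq_len]`, the method multiplies a `[batch, 1, 1, seq_len]` view by a `[batch, 1, seq_len, 1]` view, giving a pairwise mask of shape `[batch, 1, seq_len, seq_len]`. The sketch below reproduces that outer product for a single sequence in plain Python, without torch. The helper name `pairwise_mask` is hypothetical and is not part of this commit.

```python
def pairwise_mask(mask):
    """Outer product of a 1-D 0/1 padding mask with itself.

    Entry [i][j] is 1 only when both query position i and key
    position j are real (non-padding) tokens.
    """
    return [[q * k for k in mask] for q in mask]


if __name__ == "__main__":
    # Three real tokens followed by one padding position.
    for row in pairwise_mask([1, 1, 1, 0]):
        print(row)
```

Rows and columns that belong to padding positions come out as all zeros, so padded tokens neither attend nor get attended to.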
711
+
712
+ class ProSSTEmbeddings(nn.Module):
713
+ """Construct the embeddings from word, position and token_type embeddings."""
714
+
715
+ def __init__(self, config):
716
+ super().__init__()
717
+ self.config = config
718
+ self.pad_token_id = config.pad_token_id
719
+ self.embedding_size = config.hidden_size
720
+ self.word_embeddings = nn.Embedding(
721
+ config.vocab_size, self.embedding_size, padding_idx=self.pad_token_id
722
+ )
723
+ self.LayerNorm = ProSSTLayerNorm(config.hidden_size, config.layer_norm_eps)
724
+
725
+ # Absolute position embeddings
726
+ self.position_biased_input = config.position_biased_input
727
+ if not self.position_biased_input:
728
+ self.position_embeddings = None
729
+ else:
730
+ self.position_embeddings = nn.Embedding(
731
+ config.max_position_embeddings,
732
+ self.embedding_size,
733
+ padding_idx=self.pad_token_id,
734
+ )
735
+
736
+ # Token-type embeddings
737
+ if config.type_vocab_size > 0:
738
+ self.token_type_embeddings = nn.Embedding(
739
+ config.type_vocab_size, self.embedding_size
740
+ )
741
+
742
+ # SS embeddings
743
+ if config.ss_vocab_size > 0:
744
+ self.ss_embeddings = nn.Embedding(config.ss_vocab_size, self.embedding_size)
745
+ self.ss_layer_norm = ProSSTLayerNorm(
746
+ config.hidden_size, config.layer_norm_eps
747
+ )
748
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
749
+ # PLDDT Embeddings
750
+ self.rbf_16_fn = partial(rbf, v_min=0.0, v_max=1.0, n_bins=16)
751
+ self.plddt_proj = nn.Linear(16, self.embedding_size)
752
+ self.avg_plddt_proj = nn.Linear(16, self.embedding_size)
753
+
754
+ # position_ids has shape (1, max_position_embeddings); with persistent=False it is not saved in the state dict
755
+ if self.position_biased_input:
756
+ self.register_buffer(
757
+ "position_ids",
758
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
759
+ persistent=False,
760
+ )
761
+
762
+ def forward(
763
+ self,
764
+ input_ids,
765
+ attention_mask,
766
+ ss_input_ids=None,
767
+ plddt=None,
768
+ avg_plddt=None,
769
+ token_type_ids=None,
770
+ ):
771
+ inputs_embeds = self.word_embeddings(input_ids)
772
+ inputs_embeds = 0.1 * inputs_embeds + 0.9 * inputs_embeds.detach()
773
+ embeddings = inputs_embeds
774
+
775
+ if self.position_biased_input:
776
+ input_shape = input_ids.size()
777
+ seq_length = input_shape[1]
778
+ position_ids = self.position_ids[:, :seq_length]
779
+ if seq_length > position_ids.size(1):
780
+ zero_padding = (
781
+ torch.zeros(
782
+ (input_shape[0], seq_length - position_ids.size(1)),
783
+ dtype=torch.long,
784
+ device=position_ids.device,
785
+ )
786
+ + 2047
787
+ )
788
+ position_ids = torch.cat([position_ids, zero_padding], dim=1)
789
+ position_embeddings = self.position_embeddings(position_ids.long())
790
+ embeddings += position_embeddings
791
+
792
+ if self.config.type_vocab_size > 0:
793
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
794
+ embeddings += token_type_embeddings
795
+
796
+ embeddings = self.LayerNorm(embeddings)
797
+ embeddings = embeddings * attention_mask.unsqueeze(-1)
798
+ embeddings = self.dropout(embeddings)
799
+
800
+ if self.config.ss_vocab_size > 0:
801
+ ss_embeddings = self.ss_embeddings(ss_input_ids)
802
+ ss_embeddings = ss_embeddings * attention_mask.unsqueeze(-1)
803
+ # plddt
804
+ plddt_embedding = self.plddt_proj(self.rbf_16_fn(plddt))
805
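+ # NOTE: avg_plddt is projected with plddt_proj; avg_plddt_proj is defined in __init__ but is not used here.
+ avg_plddt_embedding = self.plddt_proj(self.rbf_16_fn(avg_plddt))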
806
+ ss_embeddings = ss_embeddings + plddt_embedding + avg_plddt_embedding
807
+ ss_embeddings = self.ss_layer_norm(ss_embeddings)
808
+ ss_embeddings = ss_embeddings * 0.1 + ss_embeddings.detach() * 0.9
809
+ ss_embeddings = self.dropout(ss_embeddings)
810
+ return embeddings, ss_embeddings
811
+
812
+ return embeddings, None
813
+
814
+
815
+ class ProSSTPreTrainedModel(PreTrainedModel):
816
+ """
817
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
818
+ models.
819
+ """
820
+
821
+ config_class = ProSSTConfig
822
+ base_model_prefix = "ProSST"
823
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
824
+ supports_gradient_checkpointing = True
825
+
826
+ def _init_weights(self, module):
827
+ """Initialize the weights."""
828
+ if isinstance(module, nn.Linear):
829
+ # Slightly different from the TF version which uses truncated_normal for initialization
830
+ # cf https://github.com/pytorch/pytorch/pull/5617
831
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
832
+ if module.bias is not None:
833
+ module.bias.data.zero_()
834
+ elif isinstance(module, nn.Embedding):
835
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
836
+ if module.padding_idx is not None:
837
+ module.weight.data[module.padding_idx].zero_()
838
+
839
+ def _set_gradient_checkpointing(self, module, value=False):
840
+ if isinstance(module, ProSSTEncoder):
841
+ module.gradient_checkpointing = value
842
+
843
+
844
+ class ProSSTModel(ProSSTPreTrainedModel):
845
+ def __init__(self, config):
846
+ super().__init__(config)
847
+ self.config = config
848
+ self.embeddings = ProSSTEmbeddings(config)
849
+ self.encoder = ProSSTEncoder(config)
850
+ self.post_init()
851
+
852
+ def forward(
853
+ self,
854
+ input_ids,
855
+ attention_mask,
856
+ ss_input_ids=None,
857
+ plddt=None,
858
+ avg_plddt=None,
859
+ token_type_ids=None,
860
+ output_attentions=False,
861
+ output_hidden_states=False,
862
+ ) -> BaseModelOutput:
863
+ embedding_output, ss_embeddings = self.embeddings(
864
+ input_ids=input_ids,
865
+ attention_mask=attention_mask,
866
+ ss_input_ids=ss_input_ids,
867
+ plddt=plddt,
868
+ avg_plddt=avg_plddt,
869
+ token_type_ids=token_type_ids,
870
+ )
871
+ encoder_outputs = self.encoder(
872
+ embedding_output,
873
+ attention_mask,
874
+ output_hidden_states=output_hidden_states,
875
+ output_attentions=output_attentions,
876
+ ss_hidden_states=ss_embeddings,
877
+ )
878
+ return BaseModelOutput(
879
+ last_hidden_state=encoder_outputs.last_hidden_state,
880
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
881
+ attentions=encoder_outputs.attentions if output_attentions else None,
882
+ )
883
+
884
+
885
+ class ProSSTPredictionHeadTransform(nn.Module):
886
+ def __init__(self, config):
887
+ super().__init__()
888
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
889
+ self.dense = nn.Linear(config.hidden_size, self.embedding_size)
890
+ if isinstance(config.hidden_act, str):
891
+ self.transform_act_fn = ACT2FN[config.hidden_act]
892
+ else:
893
+ self.transform_act_fn = config.hidden_act
894
+ self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
895
+
896
+ def forward(self, hidden_states):
897
+ hidden_states = self.dense(hidden_states)
898
+ hidden_states = self.transform_act_fn(hidden_states)
899
+ hidden_states = self.LayerNorm(hidden_states)
900
+ return hidden_states
901
+
902
+
903
+ class ProSSTLMPredictionHead(nn.Module):
904
+ def __init__(self, config):
905
+ super().__init__()
906
+ self.config = config
907
+ self.transform = ProSSTPredictionHeadTransform(config)
908
+ self.embedding_size = config.hidden_size
909
+ self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
910
+
911
+ def forward(self, hidden_states):
912
+ hidden_states = self.transform(hidden_states)
913
+ hidden_states = self.decoder(hidden_states)
914
+ return hidden_states
915
+
916
+
917
+ class ProSSTOnlyMLMHead(nn.Module):
918
+ def __init__(self, config):
919
+ super().__init__()
920
+ self.predictions = ProSSTLMPredictionHead(config)
921
+
922
+ def forward(self, sequence_output):
923
+ prediction_scores = self.predictions(sequence_output)
924
+ return prediction_scores
925
+
926
+
927
+ class ProSSTPreTrainedModel(PreTrainedModel):
928
+ """
929
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
930
+ models.
931
+ """
932
+
933
+ config_class = ProSSTConfig
934
+ base_model_prefix = "ProSST"
935
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
936
+ supports_gradient_checkpointing = True
937
+
938
+ def _init_weights(self, module):
939
+ """Initialize the weights."""
940
+ if isinstance(module, nn.Linear):
941
+ # Slightly different from the TF version which uses truncated_normal for initialization
942
+ # cf https://github.com/pytorch/pytorch/pull/5617
943
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
944
+ if module.bias is not None:
945
+ module.bias.data.zero_()
946
+ elif isinstance(module, nn.Embedding):
947
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
948
+ if module.padding_idx is not None:
949
+ module.weight.data[module.padding_idx].zero_()
950
+
951
+ def _set_gradient_checkpointing(self, module, value=False):
952
+ if isinstance(module, ProSSTEncoder):
953
+ module.gradient_checkpointing = value
954
+
955
+
956
+ class ProSSTForMaskedLM(ProSSTPreTrainedModel):
957
+ _tied_weights_keys = [
958
+ "cls.predictions.decoder.weight",
959
+ "cls.predictions.decoder.bias",
960
+ ]
961
+
962
+ def __init__(self, config):
963
+ super().__init__(config)
964
+ self.prosst = ProSSTModel(config)
965
+ self.cls = ProSSTOnlyMLMHead(config)
966
+ self.post_init()
967
+
968
+ def forward(
969
+ self,
970
+ input_ids,
971
+ attention_mask,
972
+ ss_input_ids=None,
973
+ plddt=None,
974
+ avg_plddt=None,
975
+ token_type_ids=None,
976
+ labels=None,
977
+ output_attentions=False,
978
+ output_hidden_states=False,
979
+ ) -> MaskedLMOutput:
980
+ outputs = self.prosst(
981
+ input_ids=input_ids,
982
+ attention_mask=attention_mask,
983
+ ss_input_ids=ss_input_ids,
984
+ plddt=plddt,
985
+ avg_plddt=avg_plddt,
986
+ token_type_ids=token_type_ids,
987
+ output_attentions=output_attentions,
988
+ output_hidden_states=output_hidden_states,
989
+ )
990
+ sequence_output = outputs[0]
991
+ prediction_scores = self.cls(sequence_output)
992
+ masked_lm_loss = None
993
+ if labels is not None:
994
+ loss_fct = CrossEntropyLoss(ignore_index=0) # label id 0 is ignored in the loss (the PyTorch default would be -100)
995
+ masked_lm_loss = loss_fct(
996
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
997
+ )
998
+ else:
999
+ masked_lm_loss = None
1000
+ return MaskedLMOutput(
1001
+ loss=masked_lm_loss,
1002
+ logits=prediction_scores,
1003
+ hidden_states=outputs.hidden_states,
1004
+ attentions=outputs.attentions,
1005
+ )
1006
+
1007
+
1008
+ class ProSSTForSequenceClassification(ProSSTPreTrainedModel):
1009
+ def __init__(self, config):
1010
+ super().__init__(config)
1011
+
1012
+ num_labels = getattr(config, "num_labels", 2)
1013
+ self.num_labels = num_labels
1014
+ self.scale_hidden = getattr(config, "scale_hidden", 1)
1015
+ self.prosst = ProSSTModel(config)
1016
+ self.pooler = ContextPooler(config)
1017
+ output_dim = self.pooler.output_dim * self.scale_hidden
1018
+
1019
+ self.classifier = nn.Linear(output_dim, num_labels)
1020
+ drop_out = getattr(config, "cls_dropout", None)
1021
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1022
+ self.dropout = nn.Dropout(drop_out)
1023
+
1024
+ # Initialize weights and apply final processing
1025
+ self.post_init()
1026
+
1027
+ def get_input_embeddings(self):
1028
+ return self.prosst.get_input_embeddings()
1029
+
1030
+ def set_input_embeddings(self, new_embeddings):
1031
+ self.prosst.set_input_embeddings(new_embeddings)
1032
+
1033
+ def forward(
1034
+ self,
1035
+ input_ids: Optional[torch.Tensor] = None,
1036
+ ss_input_ids: Optional[torch.Tensor] = None,
1037
+ attention_mask: Optional[torch.Tensor] = None,
1038
+ token_type_ids: Optional[torch.Tensor] = None,
1039
+ position_ids: Optional[torch.Tensor] = None,
1040
+ inputs_embeds: Optional[torch.Tensor] = None,
1041
+ labels: Optional[torch.Tensor] = None,
1042
+ output_attentions: Optional[bool] = None,
1043
+ output_hidden_states: Optional[bool] = None,
1044
+ return_dict: Optional[bool] = None,
1045
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1046
+ r"""
1047
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1048
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1049
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (mean squared error); if
1050
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1051
+ """
1052
+ return_dict = (
1053
+ return_dict if return_dict is not None else self.config.use_return_dict
1054
+ )
1055
+
1056
+ outputs = self.prosst(
1057
+ input_ids,
1058
+ ss_input_ids=ss_input_ids,
1059
+ token_type_ids=token_type_ids,
1060
+ attention_mask=attention_mask,
1063
+ output_attentions=output_attentions,
1064
+ output_hidden_states=output_hidden_states,
1066
+ )
1067
+
1068
+ encoder_layer = outputs[0]
1069
+ pooled_output = self.pooler(encoder_layer, attention_mask)
1070
+ pooled_output = self.dropout(pooled_output)
1071
+ logits = self.classifier(pooled_output)
1072
+
1073
+ loss = None
1074
+ if labels is not None:
1075
+ if self.config.problem_type is None:
1076
+ if self.num_labels == 1:
1077
+ # regression task
1078
+ loss_fn = nn.MSELoss()
1079
+ logits = logits.view(-1).to(labels.dtype)
1080
+ loss = loss_fn(logits, labels.view(-1))
1081
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1082
+ label_index = (labels >= 0).nonzero()
1083
+ labels = labels.long()
1084
+ if label_index.size(0) > 0:
1085
+ labeled_logits = torch.gather(
1086
+ logits,
1087
+ 0,
1088
+ label_index.expand(label_index.size(0), logits.size(1)),
1089
+ )
1090
+ labels = torch.gather(labels, 0, label_index.view(-1))
1091
+ loss_fct = CrossEntropyLoss()
1092
+ loss = loss_fct(
1093
+ labeled_logits.view(-1, self.num_labels).float(),
1094
+ labels.view(-1),
1095
+ )
1096
+ else:
1097
+ loss = torch.tensor(0).to(logits)
1098
+ else:
1099
+ log_softmax = nn.LogSoftmax(-1)
1100
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1101
+ elif self.config.problem_type == "regression":
1102
+ loss_fct = MSELoss()
1103
+ if self.num_labels == 1:
1104
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1105
+ else:
1106
+ loss = loss_fct(logits, labels)
1107
+ elif self.config.problem_type == "binary_classification":
1108
+ loss_fct = BCEWithLogitsLoss()
1109
+ loss = loss_fct(logits.squeeze(), labels.squeeze().to(logits.dtype))
1110
+ elif self.config.problem_type == "single_label_classification":
1111
+ loss_fct = CrossEntropyLoss()
1112
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1113
+ elif self.config.problem_type == "multi_label_classification":
1114
+ loss_fct = BCEWithLogitsLoss()
1115
+ loss = loss_fct(logits, labels.to(logits.dtype))
1116
+ if not return_dict:
1117
+ output = (logits,) + outputs[1:]
1118
+ return ((loss,) + output) if loss is not None else output
1119
+
1120
+ return SequenceClassifierOutput(
1121
+ loss=loss,
1122
+ logits=logits,
1123
+ hidden_states=outputs.hidden_states,
1124
+ attentions=outputs.attentions,
1125
+ )
1126
+
1127
+
1128
+ class ProSSTForTokenClassification(ProSSTPreTrainedModel):
1129
+ def __init__(self, config):
1130
+ super().__init__(config)
1131
+ self.num_labels = config.num_labels
1132
+
1133
+ self.prosst = ProSSTModel(config)
1134
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1135
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1136
+
1137
+ # Initialize weights and apply final processing
1138
+ self.post_init()
1139
+
1140
+ def forward(
1141
+ self,
1142
+ input_ids: Optional[torch.Tensor] = None,
1143
+ attention_mask: Optional[torch.Tensor] = None,
1144
+ token_type_ids: Optional[torch.Tensor] = None,
1145
+ position_ids: Optional[torch.Tensor] = None,
1146
+ inputs_embeds: Optional[torch.Tensor] = None,
1147
+ labels: Optional[torch.Tensor] = None,
1148
+ output_attentions: Optional[bool] = None,
1149
+ output_hidden_states: Optional[bool] = None,
1150
+ return_dict: Optional[bool] = None,
1151
+ ) -> Union[Tuple, TokenClassifierOutput]:
1152
+ r"""
1153
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1154
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1155
+ """
1156
+ return_dict = (
1157
+ return_dict if return_dict is not None else self.config.use_return_dict
1158
+ )
1159
+
1160
+ outputs = self.prosst(
1161
+ input_ids,
1162
+ attention_mask=attention_mask,
1163
+ token_type_ids=token_type_ids,
1166
+ output_attentions=output_attentions,
1167
+ output_hidden_states=output_hidden_states,
1169
+ )
1170
+
1171
+ sequence_output = outputs[0]
1172
+
1173
+ sequence_output = self.dropout(sequence_output)
1174
+ logits = self.classifier(sequence_output)
1175
+
1176
+ loss = None
1177
+ if labels is not None:
1178
+ loss_fct = CrossEntropyLoss()
1179
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1180
+
1181
+ if not return_dict:
1182
+ output = (logits,) + outputs[1:]
1183
+ return ((loss,) + output) if loss is not None else output
1184
+
1185
+ return TokenClassifierOutput(
1186
+ loss=loss,
1187
+ logits=logits,
1188
+ hidden_states=outputs.hidden_states,
1189
+ attentions=outputs.attentions,
1190
+ )
1191
+
1192
+
1193
+ ProSSTModel.register_for_auto_class("AutoModel")
1194
+ ProSSTForMaskedLM.register_for_auto_class("AutoModelForMaskedLM")
1195
+ ProSSTForSequenceClassification.register_for_auto_class(
1196
+ "AutoModelForSequenceClassification"
1197
+ )
1198
+ ProSSTForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")
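
Reviewer note on the masked-LM loss in `ProSSTForMaskedLM.forward`: the loss is built with `CrossEntropyLoss(ignore_index=0)`, so positions whose label is `0` contribute nothing and the mean is taken over the remaining positions only. The sketch below illustrates that averaging rule in plain Python, without torch. The function name `cross_entropy_ignoring` is hypothetical, and returning NaN when every label is ignored is an assumption about how an empty average should be reported.

```python
import math


def cross_entropy_ignoring(logits, labels, ignore_index=0):
    """Mean negative log-likelihood over positions whose label != ignore_index.

    logits: one list of per-class scores per position
    labels: one integer class id per position
    """
    losses = []
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # this position does not contribute to the loss
        peak = max(row)  # subtract the max for numerical stability
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        losses.append(log_z - row[label])
    # Assumption: report NaN when no position contributes.
    return sum(losses) / len(losses) if losses else float("nan")


if __name__ == "__main__":
    # Two positions with uniform scores; the first label equals ignore_index.
    print(cross_entropy_ignoring([[0.0, 0.0], [0.0, 0.0]], [0, 1]))
```

One consequence of this choice is that token id `0` can never act as a prediction target, so labels for unmasked positions would need to be set to `0` rather than to the `-100` sentinel that is the PyTorch default.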