徐俊德 committed on
Commit
c525dff
1 Parent(s): 53e5e8d
config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "_name_or_path": "./",
+   "activation_function": "gelu_new",
+   "architectures": [
+     "ProGenForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_progen.ProGenConfig",
+     "AutoModelForCausalLM": "modeling_InstructProGen.ProGenForCausalLM"
+   },
+   "attn_pdrop": 0.0,
+   "bos_token_id": 1,
+   "embd_pdrop": 0.0,
+   "eos_token_id": 2,
+   "gradient_checkpointing": false,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-05,
+   "model_type": "progen",
+   "n_ctx": 2048,
+   "n_embd": 4096,
+   "n_head": 16,
+   "n_inner": null,
+   "n_layer": 32,
+   "n_positions": 1024,
+   "resid_pdrop": 0.0,
+   "rotary_dim": 64,
+   "scale_attn_weights": true,
+   "structure": {
+     "embedding_keys": [
+       "mpnn_emb"
+     ],
+     "max_seqlen": 512,
+     "n_queries": 256,
+     "num_heads": 16,
+     "output_dim": 4096,
+     "structure_emb_path_prefix": "./structure_embeddings",
+     "width": 1152
+   },
+   "summary_activation": null,
+   "summary_first_dropout": 0.1,
+   "summary_proj_to_labels": true,
+   "summary_type": "cls_index",
+   "summary_use_proj": true,
+   "task_specific_params": {
+     "text-generation": {
+       "do_sample": true,
+       "max_length": 50,
+       "temperature": 1.0
+     }
+   },
+   "tie_word_embeddings": false,
+   "tokenizer_type": "iPLMTokenizer",
+   "torch_dtype": "float16",
+   "transformers_version": "4.37.2",
+   "use_cache": true,
+   "vocab_size": 30
+ }
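
Editor's note (illustrative, not part of the commit): because `auto_map` points `AutoConfig` and `AutoModelForCausalLM` at the custom `configuration_progen.ProGenConfig` and `modeling_InstructProGen.ProGenForCausalLM` classes, loading this checkpoint through the `Auto*` classes requires `trust_remote_code=True`. A minimal sketch, assuming the repository has been cloned to a local directory (the path below is a placeholder); note that modeling_InstructProGen.py also does `from .structure import StructureTransformer`, so a structure.py module (not among the files shown above) must be present alongside these files.

    from transformers import AutoConfig, AutoModelForCausalLM

    # "./InstructProGen" is a hypothetical local clone of this repository
    config = AutoConfig.from_pretrained("./InstructProGen", trust_remote_code=True)
    print(config.model_type, config.n_layer, config.n_embd)  # progen 32 4096

    model = AutoModelForCausalLM.from_pretrained(
        "./InstructProGen",
        trust_remote_code=True,
        torch_dtype="auto",  # respects "torch_dtype": "float16" from config.json
    )
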
configuration_progen.py ADDED
@@ -0,0 +1,89 @@
+ # coding=utf-8
+ # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Modified configuration implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/configuration_gptj.py
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class ProGenConfig(PretrainedConfig):
+     model_type = "progen"
+
+     def __init__(
+         self,
+         vocab_size=50400,
+         n_positions=2048,
+         n_ctx=2048,
+         n_embd=4096,
+         n_layer=28,
+         n_head=16,
+         rotary_dim=64,
+         n_inner=None,
+         activation_function="gelu_new",
+         resid_pdrop=0.0,
+         embd_pdrop=0.0,
+         attn_pdrop=0.0,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         scale_attn_weights=True,
+         gradient_checkpointing=False,
+         use_cache=True,
+         bos_token_id=50256,
+         eos_token_id=50256,
+         tie_word_embeddings=False,
+         **kwargs
+     ):
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.n_ctx = n_ctx
+         self.n_positions = n_positions
+         self.n_embd = n_embd
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.n_inner = n_inner
+         self.rotary_dim = rotary_dim
+         self.activation_function = activation_function
+         self.resid_pdrop = resid_pdrop
+         self.embd_pdrop = embd_pdrop
+         self.attn_pdrop = attn_pdrop
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+         self.gradient_checkpointing = gradient_checkpointing
+         self.scale_attn_weights = scale_attn_weights
+         self.use_cache = use_cache
+
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.tie_word_embeddings = tie_word_embeddings
+
+     @property
+     def max_position_embeddings(self):
+         return self.n_positions
+
+     @property
+     def hidden_size(self):
+         return self.n_embd
+
+     @property
+     def num_attention_heads(self):
+         return self.n_head
+
+     @property
+     def num_hidden_layers(self):
+         return self.n_layer
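
Editor's note (illustrative, not part of the commit): `ProGenConfig` keeps GPT-J-style defaults (vocab_size=50400, bos/eos 50256), which the committed config.json overrides (vocab_size=30, bos_token_id=1, eos_token_id=2), and the properties at the bottom alias the GPT-2-style field names onto the names the rest of transformers expects. A minimal sketch, assuming this file is importable (e.g. run from a clone of the repository):

    from configuration_progen import ProGenConfig

    cfg = ProGenConfig(vocab_size=30, n_layer=32, n_embd=4096, n_head=16,
                       bos_token_id=1, eos_token_id=2)
    assert cfg.hidden_size == cfg.n_embd == 4096
    assert cfg.num_attention_heads == cfg.n_head == 16
    assert cfg.num_hidden_layers == cfg.n_layer == 32
    assert cfg.max_position_embeddings == cfg.n_positions == 2048  # class default
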
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "transformers_version": "4.37.2"
+ }
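
Editor's note (illustrative, not part of the commit): this file mirrors the special-token ids from config.json and is what `generate()` falls back to when no explicit generation arguments are passed. A minimal sketch, again assuming a local clone at a placeholder path:

    from transformers import GenerationConfig

    gen_cfg = GenerationConfig.from_pretrained("./InstructProGen")  # reads generation_config.json
    print(gen_cfg.bos_token_id, gen_cfg.eos_token_id)  # 1 2
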
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7aaf74230b98697cd267b2135f00c9d9a0badaceee5d53bc9a68f790b036310a
+ size 4980651258
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2de1bff26093acaa641f8b61c061d96acf090a4af47951bd768432a108326f37
+ size 4979372896
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:670f0b393f8a9f8560b40884120fa7fa09a09d46eda2f61fc30ecd8fbb1c1b17
+ size 3148588658
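
Editor's note (illustrative, not part of the commit): the three entries above are git-lfs pointer files, not the weights themselves; `git lfs pull` (or the Hub's download APIs) replaces them with the real shards, whose sizes sum to roughly 13.1 GB. The recorded sha256 oids can be used to verify a download; a minimal sketch (the filename and hash below are the first shard listed above):

    import hashlib

    def sha256_of(path, chunk_size=1 << 20):
        # stream the file in 1 MiB chunks so multi-GB shards do not need to fit in memory
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "7aaf74230b98697cd267b2135f00c9d9a0badaceee5d53bc9a68f790b036310a"
    assert sha256_of("model-00001-of-00003.safetensors") == expected
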
model.safetensors.index.json ADDED
@@ -0,0 +1,346 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 13079216252.0
4
+ },
5
+ "weight_map": {
6
+ "lm_head.bias": "model-00003-of-00003.safetensors",
7
+ "lm_head.weight": "model-00003-of-00003.safetensors",
8
+ "transformer.h.0.attn.bias": "model-00001-of-00003.safetensors",
9
+ "transformer.h.0.attn.masked_bias": "model-00001-of-00003.safetensors",
10
+ "transformer.h.0.attn.out_proj.weight": "model-00001-of-00003.safetensors",
11
+ "transformer.h.0.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
12
+ "transformer.h.0.ln_1.bias": "model-00001-of-00003.safetensors",
13
+ "transformer.h.0.ln_1.weight": "model-00001-of-00003.safetensors",
14
+ "transformer.h.0.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
15
+ "transformer.h.0.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
16
+ "transformer.h.0.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
17
+ "transformer.h.0.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
18
+ "transformer.h.1.attn.bias": "model-00001-of-00003.safetensors",
19
+ "transformer.h.1.attn.masked_bias": "model-00001-of-00003.safetensors",
20
+ "transformer.h.1.attn.out_proj.weight": "model-00001-of-00003.safetensors",
21
+ "transformer.h.1.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
22
+ "transformer.h.1.ln_1.bias": "model-00001-of-00003.safetensors",
23
+ "transformer.h.1.ln_1.weight": "model-00001-of-00003.safetensors",
24
+ "transformer.h.1.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
25
+ "transformer.h.1.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
26
+ "transformer.h.1.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
27
+ "transformer.h.1.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
28
+ "transformer.h.10.attn.bias": "model-00001-of-00003.safetensors",
29
+ "transformer.h.10.attn.masked_bias": "model-00001-of-00003.safetensors",
30
+ "transformer.h.10.attn.out_proj.weight": "model-00001-of-00003.safetensors",
31
+ "transformer.h.10.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
32
+ "transformer.h.10.ln_1.bias": "model-00001-of-00003.safetensors",
33
+ "transformer.h.10.ln_1.weight": "model-00001-of-00003.safetensors",
34
+ "transformer.h.10.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
35
+ "transformer.h.10.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
36
+ "transformer.h.10.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
37
+ "transformer.h.10.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
38
+ "transformer.h.11.attn.bias": "model-00001-of-00003.safetensors",
39
+ "transformer.h.11.attn.masked_bias": "model-00001-of-00003.safetensors",
40
+ "transformer.h.11.attn.out_proj.weight": "model-00001-of-00003.safetensors",
41
+ "transformer.h.11.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
42
+ "transformer.h.11.ln_1.bias": "model-00001-of-00003.safetensors",
43
+ "transformer.h.11.ln_1.weight": "model-00001-of-00003.safetensors",
44
+ "transformer.h.11.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
45
+ "transformer.h.11.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
46
+ "transformer.h.11.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
47
+ "transformer.h.11.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
48
+ "transformer.h.12.attn.bias": "model-00001-of-00003.safetensors",
49
+ "transformer.h.12.attn.masked_bias": "model-00001-of-00003.safetensors",
50
+ "transformer.h.12.attn.out_proj.weight": "model-00001-of-00003.safetensors",
51
+ "transformer.h.12.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
52
+ "transformer.h.12.ln_1.bias": "model-00001-of-00003.safetensors",
53
+ "transformer.h.12.ln_1.weight": "model-00001-of-00003.safetensors",
54
+ "transformer.h.12.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
55
+ "transformer.h.12.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
56
+ "transformer.h.12.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
57
+ "transformer.h.12.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
58
+ "transformer.h.13.attn.bias": "model-00002-of-00003.safetensors",
59
+ "transformer.h.13.attn.masked_bias": "model-00002-of-00003.safetensors",
60
+ "transformer.h.13.attn.out_proj.weight": "model-00002-of-00003.safetensors",
61
+ "transformer.h.13.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
62
+ "transformer.h.13.ln_1.bias": "model-00002-of-00003.safetensors",
63
+ "transformer.h.13.ln_1.weight": "model-00002-of-00003.safetensors",
64
+ "transformer.h.13.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
65
+ "transformer.h.13.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
66
+ "transformer.h.13.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
67
+ "transformer.h.13.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
68
+ "transformer.h.14.attn.bias": "model-00002-of-00003.safetensors",
69
+ "transformer.h.14.attn.masked_bias": "model-00002-of-00003.safetensors",
70
+ "transformer.h.14.attn.out_proj.weight": "model-00002-of-00003.safetensors",
71
+ "transformer.h.14.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
72
+ "transformer.h.14.ln_1.bias": "model-00002-of-00003.safetensors",
73
+ "transformer.h.14.ln_1.weight": "model-00002-of-00003.safetensors",
74
+ "transformer.h.14.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
75
+ "transformer.h.14.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
76
+ "transformer.h.14.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
77
+ "transformer.h.14.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
78
+ "transformer.h.15.attn.bias": "model-00002-of-00003.safetensors",
79
+ "transformer.h.15.attn.masked_bias": "model-00002-of-00003.safetensors",
80
+ "transformer.h.15.attn.out_proj.weight": "model-00002-of-00003.safetensors",
81
+ "transformer.h.15.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
82
+ "transformer.h.15.ln_1.bias": "model-00002-of-00003.safetensors",
83
+ "transformer.h.15.ln_1.weight": "model-00002-of-00003.safetensors",
84
+ "transformer.h.15.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
85
+ "transformer.h.15.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
86
+ "transformer.h.15.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
87
+ "transformer.h.15.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
88
+ "transformer.h.16.attn.bias": "model-00002-of-00003.safetensors",
89
+ "transformer.h.16.attn.masked_bias": "model-00002-of-00003.safetensors",
90
+ "transformer.h.16.attn.out_proj.weight": "model-00002-of-00003.safetensors",
91
+ "transformer.h.16.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
92
+ "transformer.h.16.ln_1.bias": "model-00002-of-00003.safetensors",
93
+ "transformer.h.16.ln_1.weight": "model-00002-of-00003.safetensors",
94
+ "transformer.h.16.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
95
+ "transformer.h.16.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
96
+ "transformer.h.16.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
97
+ "transformer.h.16.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
98
+ "transformer.h.17.attn.bias": "model-00002-of-00003.safetensors",
99
+ "transformer.h.17.attn.masked_bias": "model-00002-of-00003.safetensors",
100
+ "transformer.h.17.attn.out_proj.weight": "model-00002-of-00003.safetensors",
101
+ "transformer.h.17.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
102
+ "transformer.h.17.ln_1.bias": "model-00002-of-00003.safetensors",
103
+ "transformer.h.17.ln_1.weight": "model-00002-of-00003.safetensors",
104
+ "transformer.h.17.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
105
+ "transformer.h.17.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
106
+ "transformer.h.17.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
107
+ "transformer.h.17.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
108
+ "transformer.h.18.attn.bias": "model-00002-of-00003.safetensors",
109
+ "transformer.h.18.attn.masked_bias": "model-00002-of-00003.safetensors",
110
+ "transformer.h.18.attn.out_proj.weight": "model-00002-of-00003.safetensors",
111
+ "transformer.h.18.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
112
+ "transformer.h.18.ln_1.bias": "model-00002-of-00003.safetensors",
113
+ "transformer.h.18.ln_1.weight": "model-00002-of-00003.safetensors",
114
+ "transformer.h.18.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
115
+ "transformer.h.18.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
116
+ "transformer.h.18.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
117
+ "transformer.h.18.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
118
+ "transformer.h.19.attn.bias": "model-00002-of-00003.safetensors",
119
+ "transformer.h.19.attn.masked_bias": "model-00002-of-00003.safetensors",
120
+ "transformer.h.19.attn.out_proj.weight": "model-00002-of-00003.safetensors",
121
+ "transformer.h.19.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
122
+ "transformer.h.19.ln_1.bias": "model-00002-of-00003.safetensors",
123
+ "transformer.h.19.ln_1.weight": "model-00002-of-00003.safetensors",
124
+ "transformer.h.19.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
125
+ "transformer.h.19.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
126
+ "transformer.h.19.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
127
+ "transformer.h.19.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
128
+ "transformer.h.2.attn.bias": "model-00001-of-00003.safetensors",
129
+ "transformer.h.2.attn.masked_bias": "model-00001-of-00003.safetensors",
130
+ "transformer.h.2.attn.out_proj.weight": "model-00001-of-00003.safetensors",
131
+ "transformer.h.2.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
132
+ "transformer.h.2.ln_1.bias": "model-00001-of-00003.safetensors",
133
+ "transformer.h.2.ln_1.weight": "model-00001-of-00003.safetensors",
134
+ "transformer.h.2.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
135
+ "transformer.h.2.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
136
+ "transformer.h.2.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
137
+ "transformer.h.2.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
138
+ "transformer.h.20.attn.bias": "model-00002-of-00003.safetensors",
139
+ "transformer.h.20.attn.masked_bias": "model-00002-of-00003.safetensors",
140
+ "transformer.h.20.attn.out_proj.weight": "model-00002-of-00003.safetensors",
141
+ "transformer.h.20.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
142
+ "transformer.h.20.ln_1.bias": "model-00002-of-00003.safetensors",
143
+ "transformer.h.20.ln_1.weight": "model-00002-of-00003.safetensors",
144
+ "transformer.h.20.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
145
+ "transformer.h.20.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
146
+ "transformer.h.20.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
147
+ "transformer.h.20.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
148
+ "transformer.h.21.attn.bias": "model-00002-of-00003.safetensors",
149
+ "transformer.h.21.attn.masked_bias": "model-00002-of-00003.safetensors",
150
+ "transformer.h.21.attn.out_proj.weight": "model-00002-of-00003.safetensors",
151
+ "transformer.h.21.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
152
+ "transformer.h.21.ln_1.bias": "model-00002-of-00003.safetensors",
153
+ "transformer.h.21.ln_1.weight": "model-00002-of-00003.safetensors",
154
+ "transformer.h.21.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
155
+ "transformer.h.21.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
156
+ "transformer.h.21.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
157
+ "transformer.h.21.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
158
+ "transformer.h.22.attn.bias": "model-00002-of-00003.safetensors",
159
+ "transformer.h.22.attn.masked_bias": "model-00002-of-00003.safetensors",
160
+ "transformer.h.22.attn.out_proj.weight": "model-00002-of-00003.safetensors",
161
+ "transformer.h.22.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
162
+ "transformer.h.22.ln_1.bias": "model-00002-of-00003.safetensors",
163
+ "transformer.h.22.ln_1.weight": "model-00002-of-00003.safetensors",
164
+ "transformer.h.22.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
165
+ "transformer.h.22.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
166
+ "transformer.h.22.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
167
+ "transformer.h.22.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
168
+ "transformer.h.23.attn.bias": "model-00002-of-00003.safetensors",
169
+ "transformer.h.23.attn.masked_bias": "model-00002-of-00003.safetensors",
170
+ "transformer.h.23.attn.out_proj.weight": "model-00002-of-00003.safetensors",
171
+ "transformer.h.23.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
172
+ "transformer.h.23.ln_1.bias": "model-00002-of-00003.safetensors",
173
+ "transformer.h.23.ln_1.weight": "model-00002-of-00003.safetensors",
174
+ "transformer.h.23.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
175
+ "transformer.h.23.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
176
+ "transformer.h.23.mlp.fc_out.bias": "model-00002-of-00003.safetensors",
177
+ "transformer.h.23.mlp.fc_out.weight": "model-00002-of-00003.safetensors",
178
+ "transformer.h.24.attn.bias": "model-00002-of-00003.safetensors",
179
+ "transformer.h.24.attn.masked_bias": "model-00002-of-00003.safetensors",
180
+ "transformer.h.24.attn.out_proj.weight": "model-00002-of-00003.safetensors",
181
+ "transformer.h.24.attn.qkv_proj.weight": "model-00002-of-00003.safetensors",
182
+ "transformer.h.24.ln_1.bias": "model-00002-of-00003.safetensors",
183
+ "transformer.h.24.ln_1.weight": "model-00002-of-00003.safetensors",
184
+ "transformer.h.24.mlp.fc_in.bias": "model-00002-of-00003.safetensors",
185
+ "transformer.h.24.mlp.fc_in.weight": "model-00002-of-00003.safetensors",
186
+ "transformer.h.24.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
187
+ "transformer.h.24.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
188
+ "transformer.h.25.attn.bias": "model-00003-of-00003.safetensors",
189
+ "transformer.h.25.attn.masked_bias": "model-00003-of-00003.safetensors",
190
+ "transformer.h.25.attn.out_proj.weight": "model-00003-of-00003.safetensors",
191
+ "transformer.h.25.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
192
+ "transformer.h.25.ln_1.bias": "model-00003-of-00003.safetensors",
193
+ "transformer.h.25.ln_1.weight": "model-00003-of-00003.safetensors",
194
+ "transformer.h.25.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
195
+ "transformer.h.25.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
196
+ "transformer.h.25.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
197
+ "transformer.h.25.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
198
+ "transformer.h.26.attn.bias": "model-00003-of-00003.safetensors",
199
+ "transformer.h.26.attn.masked_bias": "model-00003-of-00003.safetensors",
200
+ "transformer.h.26.attn.out_proj.weight": "model-00003-of-00003.safetensors",
201
+ "transformer.h.26.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
202
+ "transformer.h.26.ln_1.bias": "model-00003-of-00003.safetensors",
203
+ "transformer.h.26.ln_1.weight": "model-00003-of-00003.safetensors",
204
+ "transformer.h.26.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
205
+ "transformer.h.26.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
206
+ "transformer.h.26.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
207
+ "transformer.h.26.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
208
+ "transformer.h.27.attn.bias": "model-00003-of-00003.safetensors",
209
+ "transformer.h.27.attn.masked_bias": "model-00003-of-00003.safetensors",
210
+ "transformer.h.27.attn.out_proj.weight": "model-00003-of-00003.safetensors",
211
+ "transformer.h.27.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
212
+ "transformer.h.27.ln_1.bias": "model-00003-of-00003.safetensors",
213
+ "transformer.h.27.ln_1.weight": "model-00003-of-00003.safetensors",
214
+ "transformer.h.27.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
215
+ "transformer.h.27.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
216
+ "transformer.h.27.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
217
+ "transformer.h.27.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
218
+ "transformer.h.28.attn.bias": "model-00003-of-00003.safetensors",
219
+ "transformer.h.28.attn.masked_bias": "model-00003-of-00003.safetensors",
220
+ "transformer.h.28.attn.out_proj.weight": "model-00003-of-00003.safetensors",
221
+ "transformer.h.28.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
222
+ "transformer.h.28.ln_1.bias": "model-00003-of-00003.safetensors",
223
+ "transformer.h.28.ln_1.weight": "model-00003-of-00003.safetensors",
224
+ "transformer.h.28.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
225
+ "transformer.h.28.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
226
+ "transformer.h.28.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
227
+ "transformer.h.28.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
228
+ "transformer.h.29.attn.bias": "model-00003-of-00003.safetensors",
229
+ "transformer.h.29.attn.masked_bias": "model-00003-of-00003.safetensors",
230
+ "transformer.h.29.attn.out_proj.weight": "model-00003-of-00003.safetensors",
231
+ "transformer.h.29.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
232
+ "transformer.h.29.ln_1.bias": "model-00003-of-00003.safetensors",
233
+ "transformer.h.29.ln_1.weight": "model-00003-of-00003.safetensors",
234
+ "transformer.h.29.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
235
+ "transformer.h.29.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
236
+ "transformer.h.29.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
237
+ "transformer.h.29.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
238
+ "transformer.h.3.attn.bias": "model-00001-of-00003.safetensors",
239
+ "transformer.h.3.attn.masked_bias": "model-00001-of-00003.safetensors",
240
+ "transformer.h.3.attn.out_proj.weight": "model-00001-of-00003.safetensors",
241
+ "transformer.h.3.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
242
+ "transformer.h.3.ln_1.bias": "model-00001-of-00003.safetensors",
243
+ "transformer.h.3.ln_1.weight": "model-00001-of-00003.safetensors",
244
+ "transformer.h.3.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
245
+ "transformer.h.3.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
246
+ "transformer.h.3.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
247
+ "transformer.h.3.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
248
+ "transformer.h.30.attn.bias": "model-00003-of-00003.safetensors",
249
+ "transformer.h.30.attn.masked_bias": "model-00003-of-00003.safetensors",
250
+ "transformer.h.30.attn.out_proj.weight": "model-00003-of-00003.safetensors",
251
+ "transformer.h.30.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
252
+ "transformer.h.30.ln_1.bias": "model-00003-of-00003.safetensors",
253
+ "transformer.h.30.ln_1.weight": "model-00003-of-00003.safetensors",
254
+ "transformer.h.30.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
255
+ "transformer.h.30.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
256
+ "transformer.h.30.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
257
+ "transformer.h.30.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
258
+ "transformer.h.31.attn.bias": "model-00003-of-00003.safetensors",
259
+ "transformer.h.31.attn.masked_bias": "model-00003-of-00003.safetensors",
260
+ "transformer.h.31.attn.out_proj.weight": "model-00003-of-00003.safetensors",
261
+ "transformer.h.31.attn.qkv_proj.weight": "model-00003-of-00003.safetensors",
262
+ "transformer.h.31.ln_1.bias": "model-00003-of-00003.safetensors",
263
+ "transformer.h.31.ln_1.weight": "model-00003-of-00003.safetensors",
264
+ "transformer.h.31.mlp.fc_in.bias": "model-00003-of-00003.safetensors",
265
+ "transformer.h.31.mlp.fc_in.weight": "model-00003-of-00003.safetensors",
266
+ "transformer.h.31.mlp.fc_out.bias": "model-00003-of-00003.safetensors",
267
+ "transformer.h.31.mlp.fc_out.weight": "model-00003-of-00003.safetensors",
268
+ "transformer.h.4.attn.bias": "model-00001-of-00003.safetensors",
269
+ "transformer.h.4.attn.masked_bias": "model-00001-of-00003.safetensors",
270
+ "transformer.h.4.attn.out_proj.weight": "model-00001-of-00003.safetensors",
271
+ "transformer.h.4.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
272
+ "transformer.h.4.ln_1.bias": "model-00001-of-00003.safetensors",
273
+ "transformer.h.4.ln_1.weight": "model-00001-of-00003.safetensors",
274
+ "transformer.h.4.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
275
+ "transformer.h.4.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
276
+ "transformer.h.4.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
277
+ "transformer.h.4.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
278
+ "transformer.h.5.attn.bias": "model-00001-of-00003.safetensors",
279
+ "transformer.h.5.attn.masked_bias": "model-00001-of-00003.safetensors",
280
+ "transformer.h.5.attn.out_proj.weight": "model-00001-of-00003.safetensors",
281
+ "transformer.h.5.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
282
+ "transformer.h.5.ln_1.bias": "model-00001-of-00003.safetensors",
283
+ "transformer.h.5.ln_1.weight": "model-00001-of-00003.safetensors",
284
+ "transformer.h.5.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
285
+ "transformer.h.5.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
286
+ "transformer.h.5.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
287
+ "transformer.h.5.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
288
+ "transformer.h.6.attn.bias": "model-00001-of-00003.safetensors",
289
+ "transformer.h.6.attn.masked_bias": "model-00001-of-00003.safetensors",
290
+ "transformer.h.6.attn.out_proj.weight": "model-00001-of-00003.safetensors",
291
+ "transformer.h.6.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
292
+ "transformer.h.6.ln_1.bias": "model-00001-of-00003.safetensors",
293
+ "transformer.h.6.ln_1.weight": "model-00001-of-00003.safetensors",
294
+ "transformer.h.6.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
295
+ "transformer.h.6.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
296
+ "transformer.h.6.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
297
+ "transformer.h.6.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
298
+ "transformer.h.7.attn.bias": "model-00001-of-00003.safetensors",
299
+ "transformer.h.7.attn.masked_bias": "model-00001-of-00003.safetensors",
300
+ "transformer.h.7.attn.out_proj.weight": "model-00001-of-00003.safetensors",
301
+ "transformer.h.7.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
302
+ "transformer.h.7.ln_1.bias": "model-00001-of-00003.safetensors",
303
+ "transformer.h.7.ln_1.weight": "model-00001-of-00003.safetensors",
304
+ "transformer.h.7.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
305
+ "transformer.h.7.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
306
+ "transformer.h.7.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
307
+ "transformer.h.7.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
308
+ "transformer.h.8.attn.bias": "model-00001-of-00003.safetensors",
309
+ "transformer.h.8.attn.masked_bias": "model-00001-of-00003.safetensors",
310
+ "transformer.h.8.attn.out_proj.weight": "model-00001-of-00003.safetensors",
311
+ "transformer.h.8.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
312
+ "transformer.h.8.ln_1.bias": "model-00001-of-00003.safetensors",
313
+ "transformer.h.8.ln_1.weight": "model-00001-of-00003.safetensors",
314
+ "transformer.h.8.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
315
+ "transformer.h.8.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
316
+ "transformer.h.8.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
317
+ "transformer.h.8.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
318
+ "transformer.h.9.attn.bias": "model-00001-of-00003.safetensors",
319
+ "transformer.h.9.attn.masked_bias": "model-00001-of-00003.safetensors",
320
+ "transformer.h.9.attn.out_proj.weight": "model-00001-of-00003.safetensors",
321
+ "transformer.h.9.attn.qkv_proj.weight": "model-00001-of-00003.safetensors",
322
+ "transformer.h.9.ln_1.bias": "model-00001-of-00003.safetensors",
323
+ "transformer.h.9.ln_1.weight": "model-00001-of-00003.safetensors",
324
+ "transformer.h.9.mlp.fc_in.bias": "model-00001-of-00003.safetensors",
325
+ "transformer.h.9.mlp.fc_in.weight": "model-00001-of-00003.safetensors",
326
+ "transformer.h.9.mlp.fc_out.bias": "model-00001-of-00003.safetensors",
327
+ "transformer.h.9.mlp.fc_out.weight": "model-00001-of-00003.safetensors",
328
+ "transformer.ln_f.bias": "model-00003-of-00003.safetensors",
329
+ "transformer.ln_f.weight": "model-00003-of-00003.safetensors",
330
+ "transformer.structure.attn_pool.attn.in_proj_bias": "model-00003-of-00003.safetensors",
331
+ "transformer.structure.attn_pool.attn.in_proj_weight": "model-00003-of-00003.safetensors",
332
+ "transformer.structure.attn_pool.attn.out_proj.bias": "model-00003-of-00003.safetensors",
333
+ "transformer.structure.attn_pool.attn.out_proj.weight": "model-00003-of-00003.safetensors",
334
+ "transformer.structure.attn_pool.kv_proj.weight": "model-00003-of-00003.safetensors",
335
+ "transformer.structure.attn_pool.latents": "model-00003-of-00003.safetensors",
336
+ "transformer.structure.attn_pool.ln_kv.bias": "model-00003-of-00003.safetensors",
337
+ "transformer.structure.attn_pool.ln_kv.weight": "model-00003-of-00003.safetensors",
338
+ "transformer.structure.attn_pool.ln_post.bias": "model-00003-of-00003.safetensors",
339
+ "transformer.structure.attn_pool.ln_post.weight": "model-00003-of-00003.safetensors",
340
+ "transformer.structure.attn_pool.ln_q.bias": "model-00003-of-00003.safetensors",
341
+ "transformer.structure.attn_pool.ln_q.weight": "model-00003-of-00003.safetensors",
342
+ "transformer.structure.attn_pool.pos_embed": "model-00003-of-00003.safetensors",
343
+ "transformer.structure.attn_pool.proj.weight": "model-00003-of-00003.safetensors",
344
+ "transformer.wte.weight": "model-00001-of-00003.safetensors"
345
+ }
346
+ }
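
Editor's note (illustrative, not part of the commit): model.safetensors.index.json is the sharding manifest that `from_pretrained` uses to pull each tensor from the right shard; `weight_map` maps parameter names to shard filenames and `metadata.total_size` records the total tensor bytes. A minimal sketch of reading it directly, assuming a local clone:

    import json
    from collections import Counter

    with open("model.safetensors.index.json") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    print(weight_map["transformer.wte.weight"])  # model-00001-of-00003.safetensors
    print(Counter(weight_map.values()))          # number of tensors stored in each shard
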
modeling_InstructProGen.py ADDED
@@ -0,0 +1,700 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+
10
+ from transformers.activations import ACT2FN
11
+ from transformers.generation.configuration_utils import GenerationConfig
12
+ from transformers.generation.logits_process import LogitsProcessorList
13
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
14
+ from transformers.generation.streamers import BaseStreamer
15
+ from transformers.generation.utils import GenerateOutput
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
+ from .configuration_progen import ProGenConfig
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ from .structure import StructureTransformer
25
+
26
+
27
+ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
28
+ dim = x.shape[-1]
29
+ if seq_len is None:
30
+ seq_len = x.shape[seq_dim]
31
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
32
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
33
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
34
+
35
+
36
+ def rotate_every_two(x):
37
+ x1 = x[:, :, :, ::2]
38
+ x2 = x[:, :, :, 1::2]
39
+ x = torch.stack((-x2, x1), axis=-1)
40
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
41
+
42
+
43
+ def apply_rotary_pos_emb(x, sincos, offset=0):
44
+ sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)
45
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
46
+ return (x * cos) + (rotate_every_two(x) * sin)
47
+
48
+
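# Editor's note (illustrative, not part of the original commit): the three helpers above
# implement GPT-J-style rotary position embeddings. fixed_pos_embedding() builds sin/cos
# tables of shape (seq_len, dim // 2); apply_rotary_pos_emb() repeat-interleaves them to
# (1, seq_len, 1, dim) and rotates each adjacent (even, odd) feature pair of x, which has
# shape (batch, seq, heads, head_dim) at the call site. Only the first `rotary_dim`
# features are rotated by the caller; the remainder pass through unchanged.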
49
+ class ProGenAttention(nn.Module):
50
+ def __init__(self, config):
51
+ super().__init__()
52
+
53
+ max_positions = config.max_position_embeddings
54
+ self.register_buffer(
55
+ "bias",
56
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
57
+ 1, 1, max_positions, max_positions
58
+ ),
59
+ )
60
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
61
+
62
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
63
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
64
+
65
+ self.embed_dim = config.hidden_size
66
+ self.num_attention_heads = config.num_attention_heads
67
+ self.head_dim = self.embed_dim // self.num_attention_heads
68
+ if self.head_dim * self.num_attention_heads != self.embed_dim:
69
+ raise ValueError(
70
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
71
+ )
72
+ self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
73
+ self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
74
+
75
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
76
+ self.rotary_dim = None
77
+ if config.rotary_dim is not None:
78
+ self.rotary_dim = config.rotary_dim
79
+
80
+ def _split_heads(self, x, n_head, dim_head, mp_num):
81
+ reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head))
82
+ reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:])
83
+ return reshaped
84
+
85
+ def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
86
+ """
87
+ Merges attn_head_size dim and num_attn_heads dim into n_ctx
88
+ """
89
+ if len(tensor.shape) == 5:
90
+ tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
91
+ elif len(tensor.shape) == 4:
92
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
93
+ else:
94
+ raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
95
+ new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
96
+ return tensor.view(new_shape)
97
+
98
+ def _attn(
99
+ self,
100
+ query,
101
+ key,
102
+ value,
103
+ attention_mask=None,
104
+ head_mask=None,
105
+ ):
106
+
107
+ # compute causal mask from causal mask buffer
108
+ query_length, key_length = query.size(-2), key.size(-2)
109
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
110
+
111
+ # Keep the attention weights computation in fp32 to avoid overflow issues
112
+ query = query.to(torch.float32)
113
+ key = key.to(torch.float32)
114
+
115
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
116
+
117
+ attn_weights = attn_weights / self.scale_attn
118
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
119
+
120
+ if attention_mask is not None:
121
+ # Apply the attention mask
122
+ attn_weights = attn_weights + attention_mask
123
+
124
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
125
+ attn_weights = attn_weights.to(value.dtype)
126
+ attn_weights = self.attn_dropout(attn_weights)
127
+
128
+ # Mask heads if we want to
129
+ if head_mask is not None:
130
+ attn_weights = attn_weights * head_mask
131
+
132
+ attn_output = torch.matmul(attn_weights, value)
133
+
134
+ return attn_output, attn_weights
135
+
136
+ def forward(
137
+ self,
138
+ hidden_states,
139
+ attention_mask=None,
140
+ layer_past=None,
141
+ head_mask=None,
142
+ use_cache=False,
143
+ output_attentions=False,
144
+ ):
145
+
146
+ qkv = self.qkv_proj(hidden_states)
147
+ # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
148
+ # mp_num = 4
149
+ mp_num = 8
150
+ qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
151
+
152
+ local_dim = self.head_dim * self.num_attention_heads // mp_num
153
+ query, value, key = torch.split(qkv_split, local_dim, dim=-1)
154
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
155
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
156
+
157
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
158
+ value = value.permute(0, 2, 1, 3)
159
+
160
+ seq_len = key.shape[1]
161
+ offset = 0
162
+
163
+ if layer_past is not None:
164
+ offset = layer_past[0].shape[-2]
165
+ seq_len += offset
166
+
167
+ if self.rotary_dim is not None:
168
+ k_rot = key[:, :, :, : self.rotary_dim]
169
+ k_pass = key[:, :, :, self.rotary_dim :]
170
+
171
+ q_rot = query[:, :, :, : self.rotary_dim]
172
+ q_pass = query[:, :, :, self.rotary_dim :]
173
+
174
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
175
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
176
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
177
+
178
+ key = torch.cat([k_rot, k_pass], dim=-1)
179
+ query = torch.cat([q_rot, q_pass], dim=-1)
180
+ else:
181
+ sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
182
+ key = apply_rotary_pos_emb(key, sincos, offset=offset)
183
+ query = apply_rotary_pos_emb(query, sincos, offset=offset)
184
+
185
+ key = key.permute(0, 2, 1, 3)
186
+ query = query.permute(0, 2, 1, 3)
187
+
188
+ if layer_past is not None:
189
+ past_key = layer_past[0]
190
+ past_value = layer_past[1]
191
+ key = torch.cat((past_key, key), dim=-2)
192
+ value = torch.cat((past_value, value), dim=-2)
193
+
194
+ if use_cache is True:
195
+ present = (key, value)
196
+ else:
197
+ present = None
198
+
199
+ # compute self-attention: V x Softmax(QK^T)
200
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
201
+
202
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
203
+
204
+ attn_output = self.out_proj(attn_output)
205
+ attn_output = self.resid_dropout(attn_output)
206
+
207
+ outputs = (attn_output, present)
208
+ if output_attentions:
209
+ outputs += (attn_weights,)
210
+
211
+ return outputs # a, present, (attentions)
212
+
213
+
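# Editor's note (illustrative, not part of the original commit): with the committed config
# (n_embd=4096, n_head=16, head_dim=256) the fused qkv_proj above emits 12288 features per
# token. forward() reshapes that to (..., mp_num=8, 1536) -- a layout inherited from 8-way
# TPU model parallelism, as the TODO hints -- then splits each 1536-wide group into chunks
# of local_dim = 256 * 16 // 8 = 512, and _split_heads() folds the 8 groups x 2 heads back
# into (batch, seq, 16, 256). Note the split order is (query, value, key), not (q, k, v).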
214
+ class ProGenMLP(nn.Module):
215
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
216
+ super().__init__()
217
+ embed_dim = config.n_embd
218
+
219
+ self.fc_in = nn.Linear(embed_dim, intermediate_size)
220
+ self.fc_out = nn.Linear(intermediate_size, embed_dim)
221
+
222
+ self.act = ACT2FN[config.activation_function]
223
+ self.dropout = nn.Dropout(config.resid_pdrop)
224
+
225
+ def forward(self, hidden_states):
226
+ hidden_states = self.fc_in(hidden_states)
227
+ hidden_states = self.act(hidden_states)
228
+ hidden_states = self.fc_out(hidden_states)
229
+ hidden_states = self.dropout(hidden_states)
230
+ return hidden_states
231
+
232
+
233
+ class ProGenBlock(nn.Module):
234
+ def __init__(self, config):
235
+ super().__init__()
236
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
237
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
238
+ self.attn = ProGenAttention(config)
239
+ self.mlp = ProGenMLP(inner_dim, config)
240
+
241
+ def forward(
242
+ self,
243
+ hidden_states,
244
+ layer_past=None,
245
+ attention_mask=None,
246
+ head_mask=None,
247
+ use_cache=False,
248
+ output_attentions=False,
249
+ ):
250
+ residual = hidden_states
251
+ hidden_states = self.ln_1(hidden_states)
252
+ attn_outputs = self.attn(
253
+ hidden_states,
254
+ layer_past=layer_past,
255
+ attention_mask=attention_mask,
256
+ head_mask=head_mask,
257
+ use_cache=use_cache,
258
+ output_attentions=output_attentions,
259
+ )
260
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
261
+ outputs = attn_outputs[1:]
262
+
263
+ feed_forward_hidden_states = self.mlp(hidden_states)
264
+ hidden_states = attn_output + feed_forward_hidden_states + residual
265
+
266
+ if use_cache:
267
+ outputs = (hidden_states,) + outputs
268
+ else:
269
+ outputs = (hidden_states,) + outputs[1:]
270
+
271
+ return outputs # hidden_states, present, (attentions)
272
+
273
+
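# Editor's note (illustrative, not part of the original commit): ProGenBlock uses the
# GPT-J-style "parallel" residual layout -- a single LayerNorm (ln_1) feeds both the
# attention and the MLP, and their outputs are summed with the untouched residual
# (attn_output + mlp_output + residual) -- rather than the sequential attention-then-MLP
# arrangement with two LayerNorms used in GPT-2.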
274
+ class ProGenPreTrainedModel(PreTrainedModel):
275
+ """
276
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
277
+ models.
278
+ """
279
+
280
+ config_class = ProGenConfig
281
+ base_model_prefix = "transformer"
282
+ supports_gradient_checkpointing = True
283
+ is_parallelizable = True
284
+
285
+ def __init__(self, *inputs, **kwargs):
286
+ super().__init__(*inputs, **kwargs)
287
+
288
+ def _init_weights(self, module):
289
+ """Initialize the weights."""
290
+ if isinstance(module, (nn.Linear,)):
291
+ # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
292
+ # cf https://github.com/pytorch/pytorch/pull/5617
293
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
294
+ if module.bias is not None:
295
+ module.bias.data.zero_()
296
+ elif isinstance(module, nn.Embedding):
297
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
298
+ if module.padding_idx is not None:
299
+ module.weight.data[module.padding_idx].zero_()
300
+ elif isinstance(module, nn.LayerNorm):
301
+ module.bias.data.zero_()
302
+ module.weight.data.fill_(1.0)
303
+
304
+ def _set_gradient_checkpointing(self, module, value=False):
305
+ if isinstance(module, ProGenModel):
306
+ module.gradient_checkpointing = value
307
+
308
+ class ProGenModel(ProGenPreTrainedModel):
309
+ def __init__(self, config):
310
+ super().__init__(config)
311
+
312
+ self.embed_dim = config.n_embd
313
+ self.vocab_size = config.vocab_size
314
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
315
+ self.drop = nn.Dropout(config.embd_pdrop)
316
+ self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
317
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
318
+ self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
319
+
320
+ self.gradient_checkpointing = False
321
+ self.structure = StructureTransformer(**config.structure)
322
+
323
+ self.init_weights()
324
+
325
+ # Model parallel
326
+ self.model_parallel = False
327
+ self.device_map = None
328
+
329
+
330
+ def parallelize(self, device_map=None):
331
+ # Check validity of device_map
332
+ self.device_map = (
333
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
334
+ )
335
+ assert_device_map(self.device_map, len(self.h))
336
+ self.model_parallel = True
337
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
338
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
339
+ self.wte = self.wte.to(self.first_device)
340
+ # Load onto devices
341
+ for k, v in self.device_map.items():
342
+ for block in v:
343
+ cuda_device = "cuda:" + str(k)
344
+ self.h[block] = self.h[block].to(cuda_device)
345
+ # ln_f to last
346
+ self.ln_f = self.ln_f.to(self.last_device)
347
+
348
+
349
+ def deparallelize(self):
350
+ self.model_parallel = False
351
+ self.device_map = None
352
+ self.first_device = "cpu"
353
+ self.last_device = "cpu"
354
+ self.wte = self.wte.to("cpu")
355
+ for index in range(len(self.h)):
356
+ self.h[index] = self.h[index].to("cpu")
357
+ self.ln_f = self.ln_f.to("cpu")
358
+ torch.cuda.empty_cache()
359
+
360
+ def get_input_embeddings(self):
361
+ return self.wte
362
+
363
+ def set_input_embeddings(self, new_embeddings):
364
+ self.wte = new_embeddings
365
+
366
+ def forward(
367
+ self,
368
+ input_ids=None,
369
+ past_key_values=None,
370
+ attention_mask=None,
371
+ token_type_ids=None,
372
+ position_ids=None,
373
+ head_mask=None,
374
+ inputs_embeds=None,
375
+ query_embeds=None,
376
+ use_cache=None,
377
+ output_attentions=None,
378
+ output_hidden_states=None,
379
+ return_dict=None,
380
+ ):
381
+ if past_key_values is None:
382
+ # structure encode will check if input_ids contains valid
383
+ structure_embs = self.structure.encode(input_ids)
384
+ if structure_embs is not None:
385
+ input_ids = input_ids[:, self.structure.n_queries:]
386
+ else:
387
+ structure_embs = None
388
+
389
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
390
+ output_hidden_states = (
391
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
392
+ )
393
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
395
+
396
+ if input_ids is not None and inputs_embeds is not None:
397
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
398
+ elif input_ids is not None:
399
+ input_shape = input_ids.size()
400
+ input_ids = input_ids.view(-1, input_shape[-1])
401
+ batch_size = input_ids.shape[0]
402
+ elif inputs_embeds is not None:
403
+ input_shape = inputs_embeds.size()[:-1]
404
+ batch_size = inputs_embeds.shape[0]
405
+ else:
406
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
407
+
408
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
409
+
410
+ # if token_type_ids is not None:
411
+ # token_type_ids = token_type_ids.view(-1, input_shape[-1])
412
+
413
+ if position_ids is not None:
414
+ position_ids = position_ids.view(-1, input_shape[-1])
415
+
416
+ if past_key_values is None:
417
+ past_length = 0
418
+ past_key_values = tuple([None] * len(self.h))
419
+ else:
420
+ past_length = past_key_values[0][0].size(-2)
421
+
422
+ if position_ids is None:
423
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
424
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
425
+
426
+ # Attention mask.
427
+ if attention_mask is not None:
428
+ assert batch_size > 0, "batch_size has to be defined and > 0"
429
+ attention_mask = attention_mask.view(batch_size, -1)
430
+ # We create a 3D attention mask from a 2D tensor mask.
431
+ # Sizes are [batch_size, 1, 1, to_seq_length]
432
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
433
+ # this attention mask is more simple than the triangular masking of causal attention
434
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
435
+ attention_mask = attention_mask[:, None, None, :]
436
+
437
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
438
+ # masked positions, this operation will create a tensor which is 0.0 for
439
+ # positions we want to attend and -10000.0 for masked positions.
440
+ # Since we are adding it to the raw scores before the softmax, this is
441
+ # effectively the same as removing these entirely.
442
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
443
+ attention_mask = (1.0 - attention_mask) * -10000.0
444
+
445
+ # Prepare head mask if needed
446
+ # 1.0 in head_mask indicate we keep the head
447
+ # attention_probs has shape bsz x num_attention_heads x N x N
448
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
449
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
450
+
451
+ if inputs_embeds is None:
452
+ inputs_embeds = self.wte(input_ids)
453
+
454
+ if query_embeds is not None:
455
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
456
+ input_shape = inputs_embeds.size()[:-1]
457
+
458
+ if structure_embs is not None:
459
+ inputs_embeds = torch.cat([structure_embs, inputs_embeds], dim=1)
460
+ input_shape = inputs_embeds.size()[:-1]
461
+
462
+ hidden_states = inputs_embeds
463
+
464
+ # disable token_type_ids
465
+ # if token_type_ids is not None:
466
+ # token_type_embeds = self.wte(token_type_ids)
467
+ # hidden_states = hidden_states + token_type_embeds
468
+
469
+ hidden_states = self.drop(hidden_states)
470
+
471
+ output_shape = input_shape + (hidden_states.size(-1),)
472
+
473
+ presents = () if use_cache else None
474
+ all_self_attentions = () if output_attentions else None
475
+ all_hidden_states = () if output_hidden_states else None
476
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
477
+
478
+ # Model parallel
479
+ if self.model_parallel:
480
+ torch.cuda.set_device(hidden_states.device)
481
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
482
+ if layer_past is not None:
483
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
484
+ # Ensure that attention_mask is always on the same device as hidden_states
485
+ if attention_mask is not None:
486
+ attention_mask = attention_mask.to(hidden_states.device)
487
+ if isinstance(head_mask, torch.Tensor):
488
+ head_mask = head_mask.to(hidden_states.device)
489
+ if output_hidden_states:
490
+ all_hidden_states = all_hidden_states + (hidden_states,)
491
+
492
+ if self.gradient_checkpointing and self.training:
493
+
494
+ if use_cache:
495
+ # logger.warning(
496
+ # "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
497
+ # "`use_cache=False`..."
498
+ # )
499
+ use_cache = False
500
+
501
+ def create_custom_forward(module):
502
+ def custom_forward(*inputs):
503
+ # None for past_key_value
504
+ return module(*inputs, use_cache, output_attentions)
505
+
506
+ return custom_forward
507
+
508
+ outputs = torch.utils.checkpoint.checkpoint(
509
+ create_custom_forward(block),
510
+ hidden_states,
511
+ None,
512
+ attention_mask,
513
+ head_mask[i],
514
+ )
515
+ else:
516
+ outputs = block(
517
+ hidden_states,
518
+ layer_past=layer_past,
519
+ attention_mask=attention_mask,
520
+ head_mask=head_mask[i],
521
+ use_cache=use_cache,
522
+ output_attentions=output_attentions,
523
+ )
524
+
525
+ hidden_states = outputs[0]
526
+ if use_cache is True:
527
+ presents = presents + (outputs[1],)
528
+
529
+ if output_attentions:
530
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
531
+
532
+ # Model Parallel: If it's the last layer for that device, put things on the next device
533
+ if self.model_parallel:
534
+ for k, v in self.device_map.items():
535
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
536
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
537
+
538
+ hidden_states = self.ln_f(hidden_states)
539
+
540
+ hidden_states = hidden_states.view(*output_shape)
541
+ # Add last hidden state
542
+ if output_hidden_states:
543
+ all_hidden_states = all_hidden_states + (hidden_states,)
544
+
545
+ if not return_dict:
546
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
547
+
548
+ return BaseModelOutputWithPast(
549
+ last_hidden_state=hidden_states,
550
+ past_key_values=presents,
551
+ hidden_states=all_hidden_states,
552
+ attentions=all_self_attentions,
553
+ )
554
+
555
+
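# Editor's note (illustrative, not part of the original commit): the notable addition
# relative to the reference ProGen implementation appears to be the StructureTransformer
# hook in forward(): on the first pass (past_key_values is None), structure.encode(input_ids)
# may return structure embeddings; if it does, the first n_queries positions (256 in the
# committed config) are treated as the structure reference and dropped from input_ids, and
# the pooled structure vectors (output_dim=4096, matching n_embd) are prepended to the token
# embeddings, so the decoder blocks attend over [structure tokens | sequence tokens].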
556
+ class ProGenForCausalLM(ProGenPreTrainedModel):
557
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
558
+
559
+ def __init__(self, config):
560
+ super().__init__(config)
561
+ self.transformer = ProGenModel(config)
562
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
563
+ self.init_weights()
564
+
565
+ # Model parallel
566
+ self.model_parallel = False
567
+ self.device_map = None
568
+
569
+ def parallelize(self, device_map=None):
570
+ self.device_map = (
571
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
572
+ if device_map is None
573
+ else device_map
574
+ )
575
+ assert_device_map(self.device_map, len(self.transformer.h))
576
+ self.transformer.parallelize(self.device_map)
577
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
578
+ self.model_parallel = True
579
+
580
+ def deparallelize(self):
581
+ self.transformer.deparallelize()
582
+ self.transformer = self.transformer.to("cpu")
583
+ self.lm_head = self.lm_head.to("cpu")
584
+ self.model_parallel = False
585
+ torch.cuda.empty_cache()
586
+
587
+ def get_output_embeddings(self):
588
+ return self.lm_head
589
+
590
+ def set_output_embeddings(self, new_embeddings):
591
+ self.lm_head = new_embeddings
592
+
593
+ def prepare_inputs_for_generation(
594
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
595
+ ):
596
+ if past_key_values:
597
+ input_ids = input_ids[:, -1:]
598
+
599
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
600
+ if inputs_embeds is not None and past_key_values is None:
601
+ model_inputs = {"inputs_embeds": inputs_embeds}
602
+ else:
603
+ model_inputs = {"input_ids": input_ids}
604
+
605
+ model_inputs.update(
606
+ {
607
+ "past_key_values": past_key_values,
608
+ "use_cache": kwargs.get("use_cache"),
609
+ "attention_mask": attention_mask,
610
+ }
611
+ )
612
+ return model_inputs
613
+
614
+ def forward(
615
+ self,
616
+ input_ids=None,
617
+ past_key_values=None,
618
+ attention_mask=None,
619
+ token_type_ids=None,
620
+ position_ids=None,
621
+ head_mask=None,
622
+ inputs_embeds=None,
623
+ labels=None,
624
+ use_cache=None,
625
+ query_embeds=None,
626
+ output_attentions=None,
627
+ output_hidden_states=None,
628
+ return_dict=None,
629
+ ):
630
+ r"""
631
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
632
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
633
+ ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
634
+ ``-100`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``.
635
+ """
636
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
637
+
638
+ transformer_outputs = self.transformer(
639
+ input_ids,
640
+ past_key_values=past_key_values,
641
+ attention_mask=attention_mask,
642
+ token_type_ids=token_type_ids,
643
+ position_ids=position_ids,
644
+ head_mask=head_mask,
645
+ inputs_embeds=inputs_embeds,
646
+ query_embeds=query_embeds,
647
+ use_cache=use_cache,
648
+ output_attentions=output_attentions,
649
+ output_hidden_states=output_hidden_states,
650
+ return_dict=return_dict,
651
+ )
652
+ hidden_states = transformer_outputs[0]
653
+
654
+ # Set device for model parallelism
655
+ if self.model_parallel:
656
+ torch.cuda.set_device(self.transformer.first_device)
657
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
658
+
659
+ # make sure sampling in fp16 works correctly and
660
+ # compute loss in fp32 to match with mesh-tf version
661
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
662
+ lm_logits = self.lm_head(hidden_states).to(torch.float32)
663
+
664
+ loss = None
665
+ if labels is not None:
666
+ # Shift so that tokens < n predict n
667
+ shift_logits = lm_logits[..., :-1, :].contiguous()
668
+ shift_labels = labels[..., 1:].contiguous()
669
+ # Flatten the tokens
670
+ loss_fct = CrossEntropyLoss()
671
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
672
+
673
+ loss = loss.to(hidden_states.dtype)
674
+
675
+ if not return_dict:
676
+ output = (lm_logits,) + transformer_outputs[1:]
677
+ return ((loss,) + output) if loss is not None else output
678
+
679
+ return CausalLMOutputWithPast(
680
+ loss=loss,
681
+ logits=lm_logits,
682
+ past_key_values=transformer_outputs.past_key_values,
683
+ hidden_states=transformer_outputs.hidden_states,
684
+ attentions=transformer_outputs.attentions,
685
+ )
686
+
687
+ @staticmethod
688
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
689
+ """
690
+ This function is used to re-order the :obj:`past_key_values` cache if
691
+ :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
692
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
693
+ """
694
+ return tuple(
695
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
696
+ for layer_past in past
697
+ )
698
+
699
+ # def generate(self, inputs: Tensor | None = None, generation_config: GenerationConfig | None = None, logits_processor: LogitsProcessorList | None = None, stopping_criteria: StoppingCriteriaList | None = None, prefix_allowed_tokens_fn: Callable[[int, Tensor], List[int]] | None = None, synced_gpus: bool | None = None, assistant_model: PreTrainedModel | None = None, streamer: BaseStreamer | None = None, negative_prompt_ids: Tensor | None = None, negative_prompt_attention_mask: Tensor | None = None, **kwargs) -> GenerateOutput | LongTensor:
700
+ # return super().generate(inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
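(Editor's sketch, not part of the commit.) The commented-out override above keeps the default `generate()` behavior. The following is a minimal, hedged loading-and-sampling sketch: the local path "./" stands for a checkout of this repository, and the prompt and sampling settings are purely illustrative. A structure-conditioned "path|SEQUENCE" prompt is sketched after tokenization_iPLM.py further down.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./"  # assumption: a local checkout of this repository
# trust_remote_code is required because the modeling and tokenizer classes ship with the repo
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# Sequence-only prompt; note that with use_structure enabled the tokenizer prepends
# n_queries placeholder positions (masked out) in front of the sequence tokens.
inputs = tokenizer("1M", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, do_sample=True, max_new_tokens=64, temperature=1.0)
print(tokenizer.decode(out[0]))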
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "pad_token": {
3
+ "content": "<|pad|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ }
9
+ }
structure.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ import pickle
12
+ from typing import Callable, Optional, Sequence, Tuple, List
13
+ import numpy as np
14
+ import os
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+ class GLU(nn.Module):
23
+ def __init__(self,hidden_size):
24
+ super().__init__()
25
+ self.linear_proj = nn.Linear(hidden_size,hidden_size,bias=False)
26
+ self.norm1 = nn.LayerNorm(hidden_size)
27
+ self.act1 = nn.GELU()
28
+ self.act2 = nn.functional.silu
29
+ self.dense_h_to_4h = nn.Linear(hidden_size,hidden_size*4,bias=False)
30
+ self.gate_proj = nn.Linear(hidden_size,hidden_size*4,bias=False)
31
+ self.dense_4h_to_h = nn.Linear(hidden_size*4,hidden_size,bias=False)
32
+
33
+ def forward(self,x):
34
+ x = self.linear_proj(x)
35
+ x = self.act1(self.norm1(x))
36
+ x = self.act2(self.gate_proj(x))*self.dense_h_to_4h(x)
37
+ x = self.dense_4h_to_h(x)
38
+ return x
39
+ def swiglu(x):
40
+ x = torch.chunk(x, 2, dim=-1)
41
+ return nn.functional.silu(x[0]) * x[1]
42
+
43
+ class GLU_new(nn.Module):
44
+ def __init__(self,hidden_size, dropout=0.1):
45
+ super().__init__()
46
+ intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
47
+ intermediate_size = 1280  # note: this hard-coded value overrides the computed size above
48
+
49
+ self.act = swiglu
50
+ self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
51
+ self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ def forward(self,x):
55
+ x = self.dense_h_to_4h(x)
56
+ x = self.act(x)
57
+ x = self.dense_4h_to_h(x)
58
+ x = self.dropout(x)
59
+ return x
60
+
61
+
62
+ n_queries = 32
63
+ def get_abs_pos(abs_pos, tgt_size):
64
+ # abs_pos: L, C
65
+ # tgt_size: M
66
+ # return: M, C
67
+ src_size = int(math.sqrt(abs_pos.size(0)))
68
+ tgt_size = int(math.sqrt(tgt_size))
69
+ dtype = abs_pos.dtype
70
+
71
+ if src_size != tgt_size:
72
+ return F.interpolate(
73
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
74
+ size=(tgt_size, tgt_size),
75
+ mode="bicubic",
76
+ align_corners=False,
77
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
78
+ else:
79
+ return abs_pos
80
+
81
+ from einops import rearrange, repeat
82
+
83
+ def get_1d_sincos_pos_embed(embed_dim, pos):
84
+ """
85
+ embed_dim: output dimension for each position
86
+ pos: a list of positions to be encoded: size (M,)
87
+ out: (M, D)
88
+ """
89
+ assert embed_dim % 2 == 0
90
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
91
+ omega /= embed_dim / 2.
92
+ omega = 1. / 10000**omega # (D/2,)
93
+
94
+ pos = pos.reshape(-1) # (M,)
95
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
96
+
97
+ emb_sin = np.sin(out) # (M, D/2)
98
+ emb_cos = np.cos(out) # (M, D/2)
99
+
100
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
101
+ return emb
102
+
103
+ class Resampler(nn.Module):
104
+ def __init__(
105
+ self,
106
+ kv_dim,
107
+ embed_dim,
108
+ num_heads=8,
109
+ n_queries=64,
110
+ max_seqlen=1024,
111
+ perceiver_resampler_positional_emb=True,
112
+ use_GLU=False,
113
+ bos_init=False,
114
+ dropout=0.0
115
+ ):
116
+ super().__init__()
117
+ self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb
118
+
119
+ if self.perceiver_resampler_positional_emb:
120
+ assert n_queries <= max_seqlen
121
+ self.stride = max_seqlen // n_queries
122
+ # self.nan_emb = nn.Parameter(torch.randn(1, kv_dim))
123
+ # nn.init.trunc_normal_(self.nan_emb, std=.02)
124
+ pos = np.arange(max_seqlen, dtype=np.float32)
125
+ self.register_buffer(
126
+ "pos_embed",
127
+ torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
128
+ )
129
+ self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
130
+ if bos_init:
131
+ self.latents.load('')
132
+ else:
133
+ nn.init.trunc_normal_(self.latents, std=1e-3)
134
+
135
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
136
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
137
+ self.ln_q = nn.LayerNorm(embed_dim)
138
+ self.ln_kv = nn.LayerNorm(embed_dim)
139
+ self.ln_post = nn.LayerNorm(embed_dim)
140
+ if use_GLU:
141
+ print('GLU *********************************')
142
+ self.proj = GLU_new(embed_dim, dropout=dropout)
143
+ else:
144
+ self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
145
+
146
+ self.apply(self._init_weights)
147
+
148
+ def _init_weights(self, m):
149
+ if isinstance(m, nn.Linear):
150
+ nn.init.trunc_normal_(m.weight, std=1e-3)
151
+ if isinstance(m, nn.Linear) and m.bias is not None:
152
+ nn.init.constant_(m.bias, 0)
153
+ elif isinstance(m, nn.LayerNorm):
154
+ nn.init.constant_(m.bias, 0)
155
+ nn.init.constant_(m.weight, 1.0)
156
+
157
+ def forward(self, struc_x):
158
+ """
159
+ Args:
160
+ x (torch.Tensor): protein structure features
161
+ shape (B, L, C)
162
+ Returns:
163
+ shape (B, n, C) where n is n_queries (the number of latent queries)
164
+ """
165
+ x = struc_x["encoder_out"]
166
+ mask = struc_x["encoder_padding_mask"]
167
+
168
+
169
+ nan_mask = torch.isnan(x)
170
+ if nan_mask.any():
171
+ x = x.masked_fill(nan_mask, 0.0)
172
+ # nan_mask = nan_mask.sum(dim=-1).bool()
173
+ # x[nan_mask] += self.nan_emb
174
+
175
+ x = self.kv_proj(x)
176
+ x = self.ln_kv(x)
177
+
178
+ b, seqlen = x.shape[:2]
179
+
180
+ latents = self.ln_q(self.latents)
181
+ if self.perceiver_resampler_positional_emb:
182
+ # TODO: interpolate
183
+ latents = latents + self.pos_embed[::self.stride].contiguous()
184
+ pos_emb = self.pos_embed[:seqlen].unsqueeze(0)
185
+ x = x + pos_emb.contiguous()
186
+
187
+ # blocks
188
+ latents = repeat(latents, "n d -> b n d", b=b)
189
+ out = self.attn(latents, x, x, key_padding_mask=~mask)[0]
190
+
191
+ out = self.ln_post(out)
192
+ out = self.proj(out)
193
+
194
+ return out
195
+
196
+ class StructureTransformer(nn.Module):
197
+
198
+ def __init__(
199
+ self,
200
+ width: int = 640,
201
+ n_queries: int = 32,
202
+ output_dim: int = 4096,
203
+ embedding_keys=set(["mpnn_emb"]),
204
+ max_seqlen: int=1024,
205
+ num_heads: int=8,
206
+ structure_emb_path_prefix='structure_emb',
207
+ **kwargs
208
+ ):
209
+ super().__init__()
210
+
211
+ self.structure_emb_path_prefix = structure_emb_path_prefix
212
+ # self.transformer = None # replace None with a pretrained strucure encoder
213
+ self.embedding_keys = embedding_keys
214
+ self.max_seqlen = max_seqlen
215
+ self.width = width
216
+ self.n_queries = n_queries
217
+
218
+ self.attn_pool = Resampler(
219
+ embed_dim=output_dim,
220
+ kv_dim=width,
221
+ n_queries=n_queries,
222
+ max_seqlen=max_seqlen,
223
+ num_heads=num_heads,
224
+ **kwargs
225
+ )
226
+
227
+ def prepare_structure(self, sample):
228
+ emb_pad = torch.zeros((self.max_seqlen, self.width))
229
+ emb_mask = torch.zeros((self.max_seqlen), dtype=bool)
230
+
231
+ if "pifold_emb" in self.embedding_keys and "pifold_mask" in sample:
232
+ mask = sample["pifold_mask"]
233
+ pifold_emb = sample["pifold_emb"]
234
+ new_pifold_emb = pifold_emb.new_zeros(mask.shape[0], pifold_emb.shape[1]).fill_(float("nan"))
235
+ new_pifold_emb[mask > 0] = pifold_emb
236
+ sample["pifold_emb"] = new_pifold_emb
237
+
238
+ ### domains ###
239
+ emb = []
240
+ for ek in self.embedding_keys:
241
+ if ek in sample:
242
+ if isinstance(sample[ek], List):
243
+ emb.append(torch.cat(sample[ek]))
244
+ else:
245
+ emb.append(sample[ek])
246
+ # emb = [sample[ek] for ek in self.embedding_keys if ek in sample]
247
+ emb = torch.cat(emb, dim=-1)
248
+
249
+ emb_pad[:len(emb)] = emb
250
+ emb_mask[:len(emb)] = 1
251
+ return emb_pad, emb_mask
252
+
253
+ def forward(self, x):
254
+
255
+ # x = self.transformer(x)
256
+ x = self.attn_pool(x)
257
+
258
+ return x
259
+
260
+ def encode(self, structure_paths: List[str]):
261
+ structure_embs = []
262
+ structure_mask = []
263
+
264
+ for structure_path in structure_paths:
265
+ structure_path = [chr(s) for s in structure_path[:self.n_queries].tolist() if s > 0]
266
+ structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path))
267
+ if not os.path.exists(structure_path):
268
+ print('no structure found')
269
+ return None
270
+
271
+ with open(structure_path, 'rb') as f:
272
+ structure, struc_mask = self.prepare_structure(pickle.load(f))
273
+
274
+
275
+ structure_embs.append(structure)
276
+ structure_mask.append(struc_mask)
277
+
278
+ structure_embs = torch.stack(structure_embs, dim=0).to(
279
+ device=next(self.attn_pool.parameters()).device,
280
+ dtype=next(self.attn_pool.parameters()).dtype)
281
+ structure_mask = torch.stack(structure_mask, dim=0).to(
282
+ device=next(self.attn_pool.parameters()).device)
283
+
284
+ return self({
285
+ 'encoder_out': structure_embs,
286
+ 'encoder_padding_mask': structure_mask
287
+ })
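(Editor's sketch, not part of the commit.) To make the data flow of the resampler above concrete, here is a small, self-contained shape check using toy sizes that are much smaller than the shipped configuration; all numbers below are illustrative assumptions.

import torch
from structure import StructureTransformer  # the module added in this commit

# Toy sizes for a quick shape check
encoder = StructureTransformer(width=64, n_queries=8, output_dim=32, max_seqlen=16, num_heads=4)

B, L = 2, 12                                               # two structures, 12 valid residues each
feats = torch.randn(B, encoder.max_seqlen, encoder.width)  # padded per-residue features
mask = torch.zeros(B, encoder.max_seqlen, dtype=torch.bool)
mask[:, :L] = True                                         # True marks valid (non-padding) positions

out = encoder({"encoder_out": feats, "encoder_padding_mask": mask})
print(out.shape)  # expected: torch.Size([2, 8, 32]) -> (batch, n_queries, output_dim)

The Resampler cross-attends a fixed set of learned latent queries against the padded per-residue features, so the output length is always n_queries regardless of how long the input structure is.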
tokenization_iPLM.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import List, Optional, Union
2
+ from transformers import PreTrainedTokenizerFast
3
+ from tokenizers.processors import TemplateProcessing
4
+ from tokenizers import Tokenizer
5
+ from transformers.tokenization_utils_base import BatchEncoding, EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy
6
+ from transformers.utils import PaddingStrategy, TensorType
7
+ import torch
8
+
9
+ def create_tokenizer_custom(file):
10
+ with open(file, 'r') as f:
11
+ return Tokenizer.from_str(f.read())
12
+
13
+
14
+ class iPLMTokenizer(PreTrainedTokenizerFast):
15
+ def __init__(self, n_queries, use_structure=True, parallel=False, **kwargs):
16
+ super().__init__(tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')), **kwargs)
17
+ self.add_special_tokens({'pad_token': '<|pad|>'})
18
+ self.use_structure = use_structure
19
+ self.n_queries = n_queries if use_structure else 0
20
+ self.parallel = parallel
21
+ def __call__(
22
+ self,
23
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
24
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
25
+ text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
26
+ text_pair_target: Optional[
27
+ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
28
+ ] = None,
29
+ add_special_tokens: bool = True,
30
+ padding: Union[bool, str, PaddingStrategy] = False,
31
+ truncation: Union[bool, str, TruncationStrategy] = None,
32
+ max_length: Optional[int] = None,
33
+ stride: int = 0,
34
+ is_split_into_words: bool = False,
35
+ pad_to_multiple_of: Optional[int] = None,
36
+ return_tensors: Optional[Union[str, TensorType]] = None,
37
+ return_token_type_ids: Optional[bool] = None,
38
+ return_attention_mask: Optional[bool] = None,
39
+ return_overflowing_tokens: bool = False,
40
+ return_special_tokens_mask: bool = False,
41
+ return_offsets_mapping: bool = False,
42
+ return_length: bool = False,
43
+ verbose: bool = True,
44
+ **kwargs,
45
+ ) -> BatchEncoding:
46
+
47
+ raw_text = []
48
+
49
+ if not isinstance(text, list):
50
+ text = [text]
51
+
52
+ if self.use_structure:
53
+ attn_mask_prefix = torch.zeros((len(text), self.n_queries), dtype=bool)
54
+ input_ids_prefix = torch.zeros((len(text), self.n_queries), dtype=int)
55
+
56
+ for i in range(len(text)):
57
+ if '|' in text[i]:
58
+
59
+ res = text[i].split('|')
60
+ raw_text.append(res[1])
61
+
62
+ if self.use_structure:
63
+ # convert the structure id to ASCII codes and pad it to n_queries
64
+ structure_id = torch.tensor([ord(c) for c in res[0]])
65
+ input_ids_prefix[i, :len(structure_id)] = structure_id
66
+
67
+ attn_mask_prefix[i] = True
68
+ else:
69
+ raw_text.append(text[i])
70
+
71
+ batch = super().__call__(raw_text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
72
+
73
+ if self.use_structure:
74
+ batch['attention_mask'] = torch.cat([attn_mask_prefix, batch['attention_mask']], dim=1)
75
+ batch['input_ids'] = torch.cat([input_ids_prefix, batch['input_ids']], dim=1)
76
+
77
+ if "token_type_ids" in batch:
78
+ del batch["token_type_ids"]
79
+
80
+ return batch
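(Editor's sketch, not part of the commit.) A hedged illustration of the prompt convention handled by __call__ above: a structure identifier and an amino-acid sequence separated by '|'. The file name "ex.pkl" and the small n_queries are assumptions for the example; the shipped tokenizer_config.json uses n_queries=256.

from tokenization_iPLM import iPLMTokenizer  # the module added in this commit

tok = iPLMTokenizer(n_queries=16, use_structure=True, tokenizer_file="tokenizer.json")

batch = tok("ex.pkl|1MKTAYIAK", return_tensors="pt")

# The first n_queries input ids carry the ASCII codes of the structure id (zero-padded),
# with an all-True attention-mask prefix; the remaining ids are ordinary token ids for the sequence.
print(batch["input_ids"][0, :16])     # ord('e'), ord('x'), ord('.'), ... then zeros
print(batch["attention_mask"].shape)  # (1, 16 + number of sequence tokens)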
tokenizer.json ADDED
@@ -0,0 +1,95 @@
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<|pad|>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<|bos|>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<|eos|>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ }
33
+ ],
34
+ "normalizer": null,
35
+ "pre_tokenizer": {
36
+ "type": "ByteLevel",
37
+ "add_prefix_space": false,
38
+ "trim_offsets": true,
39
+ "use_regex": true
40
+ },
41
+ "post_processor": {
42
+ "type": "ByteLevel",
43
+ "add_prefix_space": true,
44
+ "trim_offsets": true,
45
+ "use_regex": true
46
+ },
47
+ "decoder": {
48
+ "type": "ByteLevel",
49
+ "add_prefix_space": true,
50
+ "trim_offsets": true,
51
+ "use_regex": true
52
+ },
53
+ "model": {
54
+ "type": "BPE",
55
+ "dropout": null,
56
+ "unk_token": null,
57
+ "continuing_subword_prefix": null,
58
+ "end_of_word_suffix": null,
59
+ "fuse_unk": false,
60
+ "byte_fallback": false,
61
+ "vocab": {
62
+ "<|pad|>": 0,
63
+ "<|bos|>": 1,
64
+ "<|eos|>": 2,
65
+ "1": 3,
66
+ "2": 4,
67
+ "A": 5,
68
+ "B": 6,
69
+ "C": 7,
70
+ "D": 8,
71
+ "E": 9,
72
+ "F": 10,
73
+ "G": 11,
74
+ "H": 12,
75
+ "I": 13,
76
+ "K": 14,
77
+ "L": 15,
78
+ "M": 16,
79
+ "N": 17,
80
+ "O": 18,
81
+ "P": 19,
82
+ "Q": 20,
83
+ "R": 21,
84
+ "S": 22,
85
+ "T": 23,
86
+ "U": 24,
87
+ "V": 25,
88
+ "W": 26,
89
+ "X": 27,
90
+ "Y": 28,
91
+ "Z": 29
92
+ },
93
+ "merges": []
94
+ }
95
+ }
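(Editor's sketch, not part of the commit.) Because the BPE model above ships with an empty merges list over a 30-symbol vocabulary, encoding is effectively character-level. A small illustrative check, assuming the file above is available in a local checkout:

from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")
enc = tok.encode("1MKT")
print(enc.tokens)  # expected: ['1', 'M', 'K', 'T']
print(enc.ids)     # expected: [3, 16, 14, 23], per the vocab above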
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "n_queries": 256,
3
+ "use_structure": true,
4
+ "tokenizer_class": "iPLMTokenizer",
5
+ "auto_map": {
6
+ "AutoTokenizer": [
7
+ "tokenization_iPLM.iPLMTokenizer",
8
+ null
9
+ ]
10
+ }
11
+ }
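(Editor's sketch, not part of the commit.) The auto_map above lets AutoTokenizer resolve the custom class from tokenization_iPLM.py when remote code is trusted; a minimal hedged loading check, assuming a local checkout at "./":

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True)
print(type(tokenizer).__name__)  # expected: iPLMTokenizer
print(tokenizer.n_queries)       # 256, taken from this config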