Bo1015 commited on
Commit
0dce0bd
1 Parent(s): a5de5b9

Upload 27 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ resources/demo.gif filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,109 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MSAGPT
2
+
3
+ <table>
4
+ <tr>
5
+ <td>
6
+ <h2>MSAGPT</h2>
7
+ <p>📖 Paper: <a href="xxx">MSAGPT: Neural Prompting Protein Structure Prediction via MSA Generative Pre-Training</a></p>
8
+ <p><b>MSAGPT</b> is a powerful protein language model (PLM). MSAGPT has 3 billion parameters with three versions of the model, MSAGPT, MSAGPT-Sft, and MSAGPT-Dpo, <b>supporting zero-shot and few-shot MSA generation</b>.</p>
9
+ <p><b>MSAGPT achieves state-of-the-art structural prediction performance on natural MSA-scarce scenarios</b>.</p>
10
+ </td>
11
+ </tr>
12
+ </table>
13
+
14
+
15
+ ## Overall Framework
16
+ <p align="center">
17
+ <img src="resources/overall_frame.png" alt="描述文字" style="display: block; margin: auto; width: 90%;">
18
+ </p>
19
+
20
+ ## Visualized Cases
21
+ Visualization of improved structure prediction compared with nature MSA.
22
+ <font color=orange>Yellow</font>: Ground truth;
23
+ <font color=purple>Purple</font>: Predictions based on MSA generated by MSAGPT;
24
+ <font color=cyan>Cyan</font>: Predictions from MSA generated by natural MSA.
25
+
26
+ <p align="center">
27
+ <img src="resources/app_case.png" alt="描述文字" style="display: block; margin: auto; width: 90%;">
28
+ </p>
29
+
30
+
31
+ ## Get Started:
32
+
33
+ ### Option 1:Deploy MSAGPT by yourself
34
+
35
+ We support GUI for model inference.
36
+
37
+ First, we need to install the dependencies.
38
+
39
+ ```bash
40
+ # CUDA >= 11.8
41
+ pip install -r requirements.txt
42
+ ```
43
+
44
+ #### Model List
45
+ You can choose to manually download the necessary weights. Then UNZIP it and put it into the **checkpoints** folder.
46
+
47
+ | Model | Type | Seq Length | Download |
48
+ |------------------|------|------------|-----------------------------------------------------------------------------------------------------------------------------------------|
49
+ | MSAGPT | Base | 16K | [🤗 Huggingface](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) [🔨 SwissArmyTransformer](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) |
50
+ | MSAGPT-SFT | Sft | 16K | [🤗 Huggingface](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) [🔨 SwissArmyTransformer](https://cloud.tsinghua.edu.cn/f/32da3eadf6e042aab2fa/?dl=1) |
51
+ | MSAGPT-DPO | Rlhf | 16K | [🤗 Huggingface](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) [🔨 SwissArmyTransformer](https://cloud.tsinghua.edu.cn/f/ebfc954a4cd24cef9243/?dl=1) | | |
52
+
53
+
54
+ #### Situation 1.1 CLI (SAT version)
55
+
56
+ Run CLI demo via:
57
+
58
+ ```bash
59
+ # Online Chat
60
+ bash scripts/cli_sat.sh --from_pretrained ./checkpoints/MSAGPT-DPO --input-source chat --stream_chat --max-gen-length 1024
61
+ ```
62
+
63
+ The program will automatically interact in the command line. You can generate replies entering the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by "\<M\>"), for example: "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG\<M\>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", where "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence, and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" are MSA prompts, and pressing enter. Enter `stop` to stop the program. The chat CLI looks like:
64
+ <p align="center">
65
+ <img src="resources/demo.gif" alt="描述文字" style="display: block; margin: auto; width: 90%;">
66
+ </p>
67
+
68
+
69
+ You can also enable the offline generation by set the **--input-source \<your input file\>** and **--output-path \<your output path\>**.
70
+ We set an input file example: *msa_input*.
71
+ ```bash
72
+ # Offline Generation
73
+ bash scripts/cli_sat.sh --from_pretrained ./checkpoints/MSAGPT-DPO --input-source <your input file> --output-path <your output path> --max-gen-length 1024
74
+ ```
75
+
76
+ #### Situation 1.2 CLI (Huggingface version)
77
+ (TODO)
78
+
79
+ #### Situation 1.3 Web Demo
80
+ (TODO)
81
+
82
+ ### Option 2:Finetuning MSAGPT
83
+
84
+ (TODO)
85
+
86
+ ### Hardware requirement
87
+
88
+ * Model Inference:
89
+ For BF16: 1 * A100(80G)
90
+
91
+ * Finetuning:
92
+
93
+ For BF16: 4 * A100(80G) *[Recommend]*.
94
+
95
+
96
+ ## License
97
+
98
+ The code in this repository is open source under the [Apache-2.0 license](./LICENSE).
99
+
100
+ If you find our work helpful, please consider citing the our paper
101
+
102
+ ```
103
+ @article{chen2024msagpt,
104
+ title={MSAGPT: Neural Prompting Protein Structure Prediction via MSA Generative Pre-Training},
105
+ author={Chen, Bo and Bei, Zhilei and Cheng, Xingyi and Li, Pan and Tang, Jie and Song, Le},
106
+ journal={arXiv preprint arXiv:2406.05347},
107
+ year={2024}
108
+ }
109
+ ```
README_zh.md ADDED
File without changes
checkpoints/MSAGPT-DPO/1/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3507871a00564c0be3a697678f521eccee2efb2d77577b0bc009d766b8f02a4
3
+ size 5721204666
checkpoints/MSAGPT-DPO/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
checkpoints/MSAGPT-DPO/model_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_class": "MSAGPT",
3
+ "tokenizer_type": "ProteinTokenizer",
4
+ "num_layers": 36,
5
+ "hidden_size": 2560,
6
+ "inner_hidden_size": 6832,
7
+ "num_attention_heads": 40,
8
+ "vocab_size": 128,
9
+ "layernorm_order": "post",
10
+ "model_parallel_size": 1,
11
+ "max_sequence_length": 2048,
12
+ "untie_head": true,
13
+ "head_num": 2,
14
+ "moe": false,
15
+ "expert": 1
16
+ }
checkpoints/MSAGPT-SFT/1/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19b7a79194615affec18617b2854602f2b77f053b80b44b31f6fd79bfb38ae68
3
+ size 5721204666
checkpoints/MSAGPT-SFT/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
checkpoints/MSAGPT-SFT/model_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_class": "MSAGPT",
3
+ "tokenizer_type": "ProteinTokenizer",
4
+ "num_layers": 36,
5
+ "hidden_size": 2560,
6
+ "inner_hidden_size": 6832,
7
+ "num_attention_heads": 40,
8
+ "vocab_size": 128,
9
+ "layernorm_order": "post",
10
+ "model_parallel_size": 1,
11
+ "max_sequence_length": 2048,
12
+ "untie_head": true,
13
+ "head_num": 2,
14
+ "moe": false,
15
+ "expert": 1
16
+ }
checkpoints/MSAGPT/1/mp_rank_00_model_states.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daaec07dca52dda4eaee8442d02c9c0f821a5e8ad81cbd280490f50f8f16e205
3
+ size 5721204666
checkpoints/MSAGPT/latest ADDED
@@ -0,0 +1 @@
 
 
1
+ 1
checkpoints/MSAGPT/model_config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_class": "MSAGPT",
3
+ "tokenizer_type": "ProteinTokenizer",
4
+ "num_layers": 36,
5
+ "hidden_size": 2560,
6
+ "inner_hidden_size": 6832,
7
+ "num_attention_heads": 40,
8
+ "vocab_size": 128,
9
+ "layernorm_order": "post",
10
+ "model_parallel_size": 1,
11
+ "max_sequence_length": 2048,
12
+ "untie_head": true,
13
+ "head_num": 2,
14
+ "moe": false,
15
+ "expert": 1
16
+ }
cli_sat.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import stat
4
+ import re
5
+ import time
6
+ import argparse
7
+ import numpy as np
8
+
9
+ from functools import partial
10
+ from typing import List, Tuple
11
+
12
+ import torch.distributed as dist
13
+ from sat.helpers import print_rank0
14
+ from sat import mpu, get_args, get_tokenizer
15
+ from utils import AdvancedBaseStrategy, BeamSearchStrategy
16
+ from model_utils import MSAGPT, FineTuneMSAGPT
17
+ from utils import chat_api
18
+
19
+
20
+
21
+ if __name__ == "__main__":
22
+ py_parser = argparse.ArgumentParser(add_help=False)
23
+ py_parser.add_argument("--sampling-strategy", type=str, default="BaseStrategy", help="Type of sampling strategy.")
24
+ py_parser.add_argument("--min-gen-length", type=int, default=0, help="The minimum length each blank should generate.")
25
+ py_parser.add_argument("--max-gen-length", type=int, default=512, help="The minimum length each blank should generate.")
26
+ py_parser.add_argument("--is-valid", action="store_true", help="Print all output generated by beam search strategy.")
27
+ py_parser.add_argument("--print-all-beams", action="store_true", help="Print all output generated by beam search strategy.")
28
+ py_parser.add_argument("--multiline_stream", action="store_true", help="streaming multiline output.")
29
+ py_parser.add_argument("--no-gap", action="store_true", help="do not generate gaps.")
30
+ py_parser.add_argument("--from_pretrained", type=str, default="./checkpoints/MSAGPT", help='pretrained ckpt')
31
+ py_parser.add_argument("--chinese", action='store_true', help='Chinese interface')
32
+ py_parser.add_argument("--stream_chat", action='store_true', help='streaming output')
33
+
34
+
35
+ py_parser = MSAGPT.add_model_specific_args(py_parser)
36
+ known, args_list = py_parser.parse_known_args()
37
+ args = get_args(args_list)
38
+ args = argparse.Namespace(**vars(args), **vars(known))
39
+ model, args = MSAGPT.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
40
+ model.eval()
41
+ rank = int(os.environ.get('RANK', 0))
42
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
43
+ if torch.cuda.is_available():
44
+ model = model.to('cuda')
45
+ from utils import proteinglm_tokenizer
46
+ tokenizer = proteinglm_tokenizer()
47
+
48
+ end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
49
+ # Get rid of all invalid tokens
50
+ invalid_slices = [0,26,28,29,30,31,32]
51
+ if args.no_gap:
52
+ invalid_slices.append(tokenizer.TokenToId('-'))
53
+ if args.sampling_strategy == "BaseStrategy":
54
+ assert not args.print_all_beams, "BaseStrategy don't support print all beams."
55
+ strategy = AdvancedBaseStrategy(
56
+ batch_size=1, invalid_slices = invalid_slices, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, min_gen_length=args.min_gen_length, no_repeat_ngram_size=args.no_repeat_ngram_size, end_tokens=end_tokens
57
+ )
58
+ elif args.sampling_strategy == "BeamSearchStrategy":
59
+ strategy = BeamSearchStrategy(
60
+ 1,
61
+ args.num_beams,
62
+ length_penalty=args.length_penalty,
63
+ consider_end=True,
64
+ end_tokens=end_tokens,
65
+ invalid_slices=invalid_slices,
66
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
67
+ min_gen_length=args.min_gen_length,
68
+ deterministic=True
69
+ )
70
+ else:
71
+ raise ValueError(f"unknown strategy {args.sampling_strategy}")
72
+
73
+
74
+
75
+ if args.input_source == 'chat':
76
+ if args.chinese:
77
+ if rank == 0:
78
+ print('欢迎使用 MSAGPT-CLI ,输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以"<M>"相连),例如:"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG",其中"PEGKQGDPGIPGEPGPPGPPGPQGARGPPG"为主序列,"VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG"为MSA prompt。 stop 终止程序'.center(20, "*"))
79
+ else:
80
+ if rank == 0:
81
+ print('Welcome to MSAGPT-CLI. Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by "<M>"), for example: "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG<M>VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG", where "PEGKQGDPGIPGEPGPPGPPGPQGARGPPG" is the main sequence, and "VTVEFVNSCLIGDMGVDGPPGQQGQPGPPG" are MSA prompts. Type "stop" to end the program.'.center(20,"*"))
82
+ with torch.no_grad():
83
+ while True:
84
+ if args.chinese:
85
+ if rank == 0:
86
+ protein_input = input("请输入需要生成虚拟MSA的蛋白序列(或加上少量MSA作为prompt,以'<M>'相连):")
87
+ else:
88
+ protein_input = None
89
+ else:
90
+ if rank == 0:
91
+ protein_input = input("Enter the protein sequence you need to generate virtual MSAs (or add a few MSAs as a prompt, connected by '<M>': ")
92
+ else:
93
+ protein_input = None
94
+ if world_size > 1:
95
+ torch.distributed.broadcast_object(protein_input, 0)
96
+ protein_input = protein_input.strip()
97
+ assert protein_input is not None
98
+
99
+ if protein_input == 'stop':
100
+ break
101
+
102
+ try:
103
+ response = chat_api(
104
+ args=args,
105
+ query=protein_input,
106
+ model=model,
107
+ tokenizer=tokenizer,
108
+ strategy=strategy
109
+ )
110
+ except Exception as e:
111
+ print(e)
112
+ break
113
+ if rank == 0 and not args.stream_chat:
114
+ if args.chinese:
115
+ print(f"{'生成的MSA'.center(20, '*')}")
116
+ else:
117
+ print(f"{'Virtual MSA'.center(20, '*')}")
118
+ if args.print_all_beams:
119
+ for idx, gen in enumerate(response):
120
+ out_str = f"Beam: {idx}".center(11,'@')
121
+ print(out_str)
122
+ for _ in gen:
123
+ print(_)
124
+ print()
125
+ else:
126
+ response = response[0]
127
+ for _ in response:
128
+ print(_)
129
+ print()
130
+ else:
131
+ chat_api(
132
+ args=args,
133
+ model=model,
134
+ tokenizer=tokenizer,
135
+ strategy=strategy
136
+ )
model_utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .model_proteinglm_clm import ProteinGLMForGeneration
2
+ from .model_msagpt import MSAGPT, FineTuneMSAGPT
model_utils/model_msagpt.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import copy
3
+ import torch
4
+ from torch.nn import functional as F
5
+ import torch.nn as nn
6
+
7
+ from .model_proteinglm_clm import ProteinGLMForGeneration
8
+
9
+
10
+ class MSAGPT(ProteinGLMForGeneration):
11
+ def __init__(self, args, transformer=None, **kwargs):
12
+ super().__init__(
13
+ args,
14
+ transformer=transformer,
15
+ **kwargs
16
+ )
17
+
18
+ @classmethod
19
+ def add_model_specific_args(cls, parser):
20
+ group = parser.add_argument_group('MSAGPT-inference', 'MSAGPT inference Configurations')
21
+ return super().add_model_specific_args(parser)
22
+
23
+ class FineTuneMSAGPT(MSAGPT):
24
+ def __init__(self, args, transformer=None, **kwargs):
25
+ super().__init__(
26
+ args,
27
+ transformer=transformer,
28
+ **kwargs
29
+ )
30
+ pass
model_utils/model_proteinglm_clm.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import copy
3
+ import torch
4
+ from torch.nn import functional as F
5
+ import torch.nn as nn
6
+ import contextlib
7
+
8
+ from sat import mpu
9
+ from sat.transformer_defaults import standard_attention, attention_fn_default
10
+ from sat.mpu.utils import split_tensor_along_last_dim, divide
11
+ from sat.mpu.layers import ColumnParallelLinear
12
+ from sat.model.base_model import BaseModel, BaseMixin
13
+ from sat.model.position_embedding import RotaryEmbedding
14
+ from sat.model.position_embedding import apply_rotary_pos_emb_index
15
+ from sat.ops import LayerNorm
16
+
17
+
18
+ class RotaryEmbeddingMixin(BaseMixin):
19
+ def __init__(
20
+ self,
21
+ fp16,
22
+ hidden_size,
23
+ num_attention_heads,
24
+ model_parallel_size,
25
+ rotary_embedding_2d=True,
26
+ ):
27
+ super().__init__()
28
+ hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
29
+ self.hidden_size_per_attention_head = hidden_size_per_attention_head
30
+ self.rotary_embedding_2d = rotary_embedding_2d
31
+ self.num_attention_heads_per_partition = divide(num_attention_heads, model_parallel_size)
32
+ self.rotary_emb = RotaryEmbedding(
33
+ # hidden_size_per_attention_head,
34
+ hidden_size_per_attention_head // 2
35
+ if rotary_embedding_2d
36
+ else hidden_size_per_attention_head,
37
+ base=10000,
38
+ precision=torch.half if fp16 else torch.bfloat16,
39
+ learnable=False,
40
+ device=torch.cuda.current_device(),
41
+ )
42
+
43
+
44
+ def attention_forward(self, hidden_states, mask, **kw_args):
45
+ attn = self.transformer.layers[kw_args["layer_id"]].attention
46
+ attention_fn = attention_fn_default
47
+ if "attention_fn" in attn.hooks:
48
+ attention_fn = attn.hooks["attention_fn"]
49
+
50
+ # [seq, b, 3 * hn * np]
51
+ mixed_raw_layer = attn.query_key_value(hidden_states)
52
+
53
+ # [seq, b, (np * 3 * hn)] --> [seq, b, np, 3 * hn]
54
+ new_tensor_shape = mixed_raw_layer.size()[:-1] + (
55
+ self.num_attention_heads_per_partition,
56
+ 3 * self.hidden_size_per_attention_head,
57
+ )
58
+ mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)
59
+
60
+ # [sq, b, np, hn]
61
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
62
+ # print(key_layer.shape)
63
+ dropout_fn = attn.attention_dropout if attn.training else None
64
+ if self.rotary_embedding_2d:
65
+ q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
66
+ k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
67
+ cos, sin = self.rotary_emb(q1, seq_len=kw_args["position_ids"].max() + 1)
68
+ position_ids, block_position_ids = \
69
+ kw_args["position_ids"][:, 0, :].transpose(0, 1).contiguous(), \
70
+ kw_args["position_ids"][:, 1, :].transpose(0, 1).contiguous()
71
+ q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
72
+ q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
73
+ query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
74
+ key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
75
+ else:
76
+ kw_args["position_ids"] = kw_args["position_ids"].transpose(0, 1)
77
+ cos, sin = self.rotary_emb(value_layer, seq_len=kw_args["position_ids"].max() + 1)
78
+ query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, kw_args["position_ids"])
79
+
80
+ context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
81
+ output = attn.dense(context_layer)
82
+
83
+ if attn.training:
84
+ output = attn.output_dropout(output)
85
+
86
+ return output
87
+
88
+
89
+ class GEGLU(torch.nn.Module):
90
+ def __init__(self):
91
+ super().__init__()
92
+ self.activation_fn = F.gelu
93
+
94
+ def forward(self, x):
95
+ # dim=-1 breaks in jit for pt<1.10
96
+ x1, x2 = x.chunk(2, dim=(x.ndim - 1))
97
+ return x1 * self.activation_fn(x2)
98
+
99
+
100
+ class DeepNormWithGLUMixin(BaseMixin):
101
+ def __init__(self, num_layers, hidden_size, inner_hidden_size=None):
102
+ super().__init__()
103
+ self.num_layers = num_layers
104
+ self.hidden_size = hidden_size
105
+ if inner_hidden_size is None:
106
+ inner_hidden_size = 4 * hidden_size * 2 // 3
107
+ self.inner_hidden_size = inner_hidden_size
108
+
109
+ def reinit(self):
110
+ for layer in self.transformer.layers:
111
+ del layer.mlp.dense_h_to_4h
112
+ layer.mlp.dense_h_to_4h = ColumnParallelLinear(
113
+ self.hidden_size,
114
+ 2 * self.inner_hidden_size,
115
+ gather_output=False,
116
+ bias=True,
117
+ params_dtype=torch.half,
118
+ module=self,
119
+ name="dense_h_to_4h",
120
+ skip_init=True,
121
+ )
122
+ del layer.mlp.activation_func
123
+ layer.mlp.activation_func = GEGLU()
124
+
125
+ def layer_forward(self, hidden_states, mask, *args, **kw_args):
126
+ """
127
+ hidden_states: [seq_len, batch, hidden_size]
128
+ mask: [(1, 1), seq_len, seq_len]
129
+ """
130
+ layer = self.transformer.layers[kw_args["layer_id"]]
131
+ # Layer norm at the begining of the transformer layer.
132
+
133
+ attention_input = layer.input_layernorm(hidden_states)
134
+
135
+ # Self attention.
136
+ attention_output = layer.attention(attention_input, mask, **kw_args)
137
+
138
+ # Residual connection.
139
+ alpha = (2 * self.num_layers) ** 0.5
140
+ hidden_states = attention_input * alpha + attention_output
141
+
142
+ mlp_input = layer.post_attention_layernorm(hidden_states)
143
+
144
+ # MLP.
145
+ mlp_output = layer.mlp(mlp_input, **kw_args)
146
+
147
+ # Second residual connection.
148
+ output = mlp_input * alpha + mlp_output
149
+
150
+ return output
151
+
152
+
153
+ class SelfAttentionWithFP32SoftmaxMixin(BaseMixin):
154
+ def __init__(self, fp16, hidden_size, num_attention_heads, model_parallel_size):
155
+ super().__init__()
156
+ self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
157
+ self.hidden_size_per_partition = divide(hidden_size, model_parallel_size)
158
+ self.scale_mask_softmax = None
159
+ self.fp16 = fp16
160
+
161
+ @staticmethod
162
+ def attention_mask_func(attention_scores, attention_mask):
163
+ attention_scores.masked_fill_(attention_mask, -10000.0)
164
+ return attention_scores
165
+
166
+ def attention_fn(
167
+ self,
168
+ query_layer,
169
+ key_layer,
170
+ value_layer,
171
+ attention_mask,
172
+ attention_dropout=None,
173
+ log_attention_weights=None,
174
+ scaling_attention_score=True,
175
+ mems=None,
176
+ **kwargs
177
+ ):
178
+
179
+ mem = mems[kwargs["layer_id"]] if mems is not None else None
180
+
181
+ # seqlen, batch, head, hidden_size
182
+ seq_len, b, nh, hidden_size = key_layer.shape
183
+
184
+ # stack, seqlen, b, head, hidden
185
+ # b, seqlen, stack, head, hidden
186
+ cache_kv = (
187
+ torch.stack((key_layer, value_layer))
188
+ .permute(2, 1, 0, 3, 4)
189
+ .detach()
190
+ .contiguous()
191
+ .view(b, seq_len, nh * hidden_size * 2)
192
+ )
193
+ kwargs["output_this_layer"]["mem_kv"] = cache_kv
194
+
195
+ if mem is not None: # the first time, mem is None
196
+ # might change batch_size
197
+ # b, seqlen, stack, head, hidden -> stack, seqlen, b, head, hidden
198
+ mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 1, 0, 3, 4)
199
+ memk, memv = mem[0], mem[1]
200
+ key_layer = torch.cat((memk, key_layer), dim=0)
201
+ value_layer = torch.cat((memv, value_layer), dim=0)
202
+
203
+
204
+ # check if use flash attention
205
+ is_low_triangle = (attention_mask == ~torch.ones_like(attention_mask, dtype=torch.bool).tril()).all()
206
+ is_full = (attention_mask is None) or (attention_mask == 0).all()
207
+ if int(torch.__version__.split('.')[0]) >= 2 and (is_full or is_low_triangle):
208
+ # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
209
+ dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
210
+ #[b, np, sq, hn]
211
+ query_layer, key_layer, value_layer = query_layer.permute(1,2,0,3).contiguous(), key_layer.permute(1,2,0,3).contiguous(), value_layer.permute(1,2,0,3).contiguous()
212
+ batch_size, num_query_heads = query_layer.shape[:2] # [b, np, s, hn]
213
+ num_kv_heads = key_layer.shape[1] # [b, np, s, hn]
214
+ key_layer = key_layer.unsqueeze(2).expand(-1, -1, num_query_heads//num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *key_layer.shape[2:])
215
+ value_layer = value_layer.unsqueeze(2).expand(-1, -1, num_query_heads//num_kv_heads, -1, -1).contiguous().view(batch_size, num_query_heads, *value_layer.shape[2:])
216
+
217
+ if dropout_p > 0 and mpu.get_cuda_rng_tracker is not None:
218
+ context = mpu.get_cuda_rng_tracker().fork()
219
+ else:
220
+ context = contextlib.nullcontext()
221
+
222
+ with context:
223
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
224
+ query_layer, key_layer, value_layer,
225
+ attn_mask=None,
226
+ dropout_p=dropout_p,
227
+ is_causal=not is_full
228
+ )
229
+
230
+
231
+ #[sq, b, np, hn]
232
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
233
+
234
+ # [sq, b, np, hn] --> [sq, b, hp]
235
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
236
+ context_layer = context_layer.view(*new_context_layer_shape)
237
+ return context_layer
238
+
239
+ else:
240
+ # standard attention
241
+ # [b, np, sq, sk]
242
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
243
+
244
+ query_key_layer_scaling_coeff = float(kwargs["layer_id"] + 1)
245
+
246
+
247
+ if scaling_attention_score:
248
+ query_layer = query_layer / (math.sqrt(self.hidden_size_per_attention_head) * query_key_layer_scaling_coeff)
249
+ # ===================================
250
+ # Raw attention scores. [b, np, s, s]
251
+ # ===================================
252
+ # [sq, b, np, hn] -> [sq, b * np, hn]
253
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
254
+ # [sk, b, np, hn] -> [sk, b * np, hn]
255
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
256
+
257
+ matmul_result = torch.empty(
258
+ output_size[0] * output_size[1],
259
+ output_size[2],
260
+ output_size[3],
261
+ dtype=query_layer.dtype,
262
+ device=torch.cuda.current_device(),
263
+ )
264
+
265
+ matmul_result = torch.baddbmm(
266
+ matmul_result,
267
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
268
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
269
+ beta=0.0,
270
+ alpha=1.0,
271
+ )
272
+
273
+ # change view to [b, np, sq, sk]
274
+ attention_scores = matmul_result.view(*output_size)
275
+
276
+ if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
277
+ # if auto-regressive, skip
278
+ attention_scores.masked_fill_(attention_mask.bool(), -float("inf"))
279
+
280
+ attention_scores = attention_scores.float()
281
+ attention_scores = attention_scores * query_key_layer_scaling_coeff
282
+
283
+
284
+ attention_probs = F.softmax(attention_scores, dim=-1)
285
+
286
+ if self.fp16:
287
+ attention_probs = attention_probs.half()
288
+ else:
289
+ attention_probs = attention_probs.bfloat16()
290
+
291
+ if attention_dropout is not None:
292
+ if mpu.get_cuda_rng_tracker() is not None:
293
+ with mpu.get_cuda_rng_tracker().fork():
294
+ attention_probs = attention_dropout(attention_probs)
295
+ else:
296
+ attention_probs = attention_dropout(attention_probs)
297
+
298
+ # =========================
299
+ # Context layer. [sq, b, hp]
300
+ # =========================
301
+
302
+ # value_layer -> context layer.
303
+ # [sk, b, np, hn] --> [b, np, sq, hn]
304
+
305
+ # context layer shape: [b, np, sq, hn]
306
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
307
+
308
+ # change view [sk, b * np, hn]
309
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
310
+
311
+ # change view [b * np, sq, sk]
312
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
313
+ # matmul: [b * np, sq, hn]
314
+
315
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
316
+
317
+ # change view [b, np, sq, hn]
318
+ context_layer = context_layer.view(*output_size)
319
+
320
+ # [b, np, sq, hn] --> [sq, b, np, hn]
321
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
322
+
323
+ # [sq, b, np, hn] --> [sq, b, hp]
324
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
325
+ context_layer = context_layer.view(*new_context_layer_shape)
326
+ return context_layer
327
+
328
+
329
+
330
+ class FinalForwardMixin(BaseMixin):
331
+ def __init__(self):
332
+ super().__init__()
333
+
334
+ def final_forward(self, logits, **kw_args):
335
+ return F.linear(logits, self.transformer.word_embeddings.weight).transpose(0, 1).contiguous()
336
+
337
+
338
+ class UntieFinalForwardMixin(BaseMixin):
339
+ def __init__(self, hidden_size, vocab_size, untie_head_num, layernorm_epsilon=1.0e-5):
340
+ super().__init__()
341
+
342
+ self.lm_head = nn.ModuleList()
343
+ for i in range(untie_head_num):
344
+ self.lm_head.append(
345
+ ColumnParallelLinear(
346
+ hidden_size,
347
+ 2 * hidden_size,
348
+ gather_output=True,
349
+ bias=False,
350
+ module=self,
351
+ name=f"lm_head.{i}",
352
+ )
353
+ ) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
354
+
355
+ self.head_layernorm = nn.ModuleList()
356
+ for i in range(untie_head_num):
357
+ self.head_layernorm.append(
358
+ LayerNorm(
359
+ hidden_size,
360
+ eps=layernorm_epsilon
361
+ )
362
+ )
363
+ self.activation_func=GEGLU()
364
+
365
+
366
+ def final_forward(self, logits, **kwargs):
367
+ logits = self.lm_head[1](logits)
368
+ logits = self.activation_func(logits)
369
+ logits = self.head_layernorm[1](logits)
370
+ return F.linear(logits, self.transformer.word_embeddings.weight).transpose(0, 1).contiguous()
371
+
372
+
373
+ class NonePositionEmbedding(BaseMixin):
374
+ def __init__(self):
375
+ super().__init__()
376
+
377
+ def position_embedding_forward(self, position_ids, output_cross_layer, **kw_args):
378
+ return None
379
+
380
+
381
+ class WordEmbedding(BaseMixin):
382
+ def __init__(self):
383
+ super().__init__()
384
+
385
+ def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
386
+ return self.transformer.word_embeddings(input_ids).transpose(0, 1)
387
+
388
+
389
+ class ProteinGLMForGeneration(BaseModel):
390
+ def __init__(self, args, transformer=None, **kwargs):
391
+ super().__init__(
392
+ args,
393
+ transformer=transformer,
394
+ **kwargs
395
+ )
396
+ self.add_mixin("glu-deepnorm", DeepNormWithGLUMixin(args.num_layers, args.hidden_size, args.inner_hidden_size))
397
+ self.add_mixin(
398
+ "fp32-softmax",
399
+ SelfAttentionWithFP32SoftmaxMixin(args.fp16, args.hidden_size, args.num_attention_heads, args.model_parallel_size),
400
+ )
401
+ if args.untie_head:
402
+ self.add_mixin("final-forward", UntieFinalForwardMixin(args.hidden_size, args.vocab_size, args.head_num))
403
+ else:
404
+ self.add_mixin("final-forward", FinalForwardMixin())
405
+ self.add_mixin("non-position-embedding", NonePositionEmbedding())
406
+ del self.transformer.position_embeddings
407
+ self.add_mixin("word-embedding", WordEmbedding())
408
+ self.add_mixin(
409
+ "rotary-embedding",
410
+ RotaryEmbeddingMixin(
411
+ args.fp16,
412
+ args.hidden_size,
413
+ args.num_attention_heads,
414
+ args.model_parallel_size,
415
+ args.rotary_embedding_2d
416
+ ),
417
+ )
418
+ self.get_mixin("glu-deepnorm").reinit()
419
+
420
+ @classmethod
421
+ def add_model_specific_args(cls, parser):
422
+ group = parser.add_argument_group('ProteinGLMForGeneration', 'ProteinGLMForGeneration Configurations')
423
+ group.add_argument('--untie-head', action='store_true', help='untie-heads')
424
+ group.add_argument('--head-num', default=1, type=int, help='head>1')
425
+ group.add_argument('--infer-type', default=1, type=int, help='1 for Generation')
426
+ group.add_argument('--rotary-embedding-2d', action='store_true',
427
+ help='If set, use 2D rotary embedding for ProtenGLM.')
428
+ return super().add_model_specific_args(parser)
msa_input ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ PPGPPGPPGKPGANGLSGERGPPGPPGPPG
2
+ SYEDQNSLLKMICQQVEAIKKEMQELKLNS<M>-AEDHKTILQMICQQVEALKNEMQEMKLNS<M>-AEDQKSLLQMICQQVEALKNEMHEMKLNS
3
+ MGSSHHHHHHSSGLVPRGSHMGAATPAERDAILLDLVRGQVAAVLGHASGEDIEPGRAFKNLGFDSLTAVELRDRLGAATGHKLPATIVFDYPNPTALAQHLRAAVL
4
+ MGSSHHHHHHSSGLVPRGSHMGAATPAERDAILLDLVRGQVAAVLGHASGEDIEPGRAFKNLGFDSLTAVELRDRLGAATGHKLPATIVFDYPNPTALAQHLRAAVL<M>-------------ITPSVESLRDLPRSERREALETLVVTEFKTALLMTEQDDLPLDESYFDLGLTSLTVNDLKQRLESLLSREIDGTLLFNSPTVQRLLDHLEEDV-
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ numpy==1.24.1
2
+ SwissArmyTransformer==0.4.11
3
+ torch==2.1.0.dev20230822+cu118
resources/app_case.png ADDED
resources/demo.gif ADDED

Git LFS Details

  • SHA256: 499be1fc8c44d53b5f717176630c60525202ac7367c3f5e3e94a3ab61b0d7da0
  • Pointer size: 132 Bytes
  • Size of remote file: 1.58 MB
resources/overall_frame.png ADDED
scripts/cli_sat.sh ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ script_path=$(realpath $0)
4
+ script_dir=$(dirname $script_path)
5
+ main_dir=$(dirname $script_dir)
6
+
7
+ MP_SIZE=1
8
+ # MODEL_NAME="MSAGPT-"
9
+ # MODEL_NAME="MSAGPT-dpo"
10
+
11
+
12
+ SEED=12345
13
+ MAX_GEN_LENGTH=128
14
+ MIN_GEN_LENGTH=0
15
+
16
+ # BeamSearchStrategy args
17
+ NUM_BEAMS=4
18
+ LENGTH_PENALTY=1.0
19
+ NO_REPEAT_NGRAM=0
20
+
21
+ # BaseStrategy args
22
+ TEMP=0.8
23
+ TOPK=0
24
+ TOPP=0.9
25
+
26
+
27
+ PORT=19865
28
+
29
+ MODEL_ARGS="--bf16 \
30
+ --skip-init \
31
+ --mode finetune \
32
+ --rotary-embedding-2d"
33
+
34
+ # --mode inference \ TODO: sat ds_config bug?
35
+
36
+ GENERATION_ARGS="--seed $SEED \
37
+ --sampling-strategy BaseStrategy \
38
+ --max-gen-length $MAX_GEN_LENGTH \
39
+ --min-gen-length $MIN_GEN_LENGTH \
40
+ --num-beams $NUM_BEAMS \
41
+ --length-penalty $LENGTH_PENALTY \
42
+ --no-repeat-ngram-size $NO_REPEAT_NGRAM \
43
+ --multiline_stream \
44
+ --temperature $TEMP \
45
+ --top_k $TOPK \
46
+ --top_p $TOPP
47
+ "
48
+ # --sampling-strategy BeamSearchStrategy \
49
+ # --no-gap
50
+
51
+
52
+ OPTIONS_NCCL="NCCL_DEBUG=VERSION NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 CUDA_LAUNCH_BLOCKING=0"
53
+
54
+ ARGS="${main_dir}/cli_sat.py \
55
+ $MODEL_ARGS \
56
+ $GENERATION_ARGS \
57
+ $*"
58
+
59
+ run_cmd="${OPTIONS_NCCL} torchrun --nproc_per_node $MP_SIZE --master_port=$PORT ${ARGS}"
60
+ echo ${run_cmd}
61
+ eval ${run_cmd}
62
+ set +x
utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .strategies import AdvancedBaseStrategy, BeamSearchStrategy
2
+ from .tokenization import proteinglm_tokenizer
3
+ from .chat import chat_api
4
+ from .utils import move_cursor_up
utils/chat.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import stat
4
+ import re
5
+ import time
6
+ import argparse
7
+ import numpy as np
8
+
9
+ from functools import partial
10
+ from typing import List, Tuple
11
+
12
+ import torch.distributed as dist
13
+ from sat.helpers import print_rank0
14
+ from sat import mpu, get_args, get_tokenizer
15
+ from sat.generation.utils import timed_name, generate_continually
16
+ from sat.generation.autoregressive_sampling import update_mems, get_masks_and_position_ids_default
17
+
18
+ from .utils import move_cursor_up, move_cursor_down
19
+
20
+
21
+ def get_masks_and_position_ids(seq, msa_len, max_gen_length, gmask=False):
22
+ context_length = seq.shape[1]
23
+ query_len = msa_len
24
+ max_msa_num = (max_gen_length - 2) // query_len
25
+ max_gen_length = max_msa_num * query_len + 2
26
+ tokens = torch.nn.functional.pad(seq, (0, max_gen_length - context_length), mode="constant", value=-1)
27
+ attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
28
+ attention_mask.tril_()
29
+ attention_mask.unsqueeze_(1)
30
+ attention_mask = (attention_mask < 0.5).bool()
31
+ # <gMASK> + <SOP>
32
+ position_ids = np.zeros(max_gen_length, dtype=int)
33
+ block_position_ids = np.zeros(max_gen_length, dtype=int)
34
+ pre = 0
35
+ for msa_idx in range(max_msa_num):
36
+ position_ids[(1 + pre): (1 + pre + query_len)] = np.arange(query_len, dtype = int)
37
+ block_position_ids[(1 + pre): (1 + pre + query_len)] = msa_idx
38
+ pre += query_len
39
+ position_ids = np.stack((position_ids, block_position_ids), axis=0)
40
+ position_ids = torch.from_numpy(position_ids).to(tokens.device)
41
+ position_ids = position_ids.unsqueeze(0)
42
+ return tokens, attention_mask, position_ids
43
+
44
+
45
+
46
+ def generation_sequence(
47
+ model,
48
+ seqs,
49
+ strategy,
50
+ max_memory_length=100000,
51
+ get_masks_and_position_ids=get_masks_and_position_ids,
52
+ stream=False,
53
+ mems=None,
54
+ **kw_args
55
+ ):
56
+ '''
57
+ seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
58
+ mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
59
+ cache, should be first mems.shape[1] parts of context_tokens.
60
+ mems are the first-level citizens here, but we don't assume what is memorized.
61
+ input mems are used when multi-phase generation.
62
+ '''
63
+ assert len(seqs.shape) == 2
64
+ # building the initial tokens, attention_mask, and position_ids
65
+ batch_size, context_length = seqs.shape
66
+ seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
67
+ tokens = seqs[..., :context_length]
68
+ # initialize generation
69
+ counter = context_length # Last fixed index is ``counter''
70
+ index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
71
+ num_beams = 1
72
+ # step-by-step generation
73
+ while counter < seqs.shape[1] - 1:
74
+ # Now, we want to generate seq[counter + 1],
75
+ # token[:, index: counter+1] needs forwarding.
76
+ # forward
77
+ tokens = tokens.reshape(batch_size * num_beams, -1)
78
+ mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
79
+ model.eval()
80
+ with torch.no_grad():
81
+ logits, *output_per_layers = model(
82
+ tokens[:, index:],
83
+ position_ids[..., index: counter],
84
+ attention_mask[..., index: counter, :counter], # TODO memlen
85
+ mems=mems,
86
+ **kw_args
87
+ )
88
+ mem_kv = [o['mem_kv'] for o in output_per_layers]
89
+ mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
90
+ logits = logits[:, -1]
91
+ index = counter
92
+ counter += 1
93
+ logits = logits.reshape(batch_size, num_beams, -1)
94
+ tokens = tokens.reshape(batch_size, num_beams, -1)
95
+ mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
96
+ tokens, mems = strategy.forward(logits, tokens, mems)
97
+ if len(tokens.shape) == 3 and num_beams == 1:
98
+ num_beams = tokens.shape[1]
99
+ position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, 2, -1).reshape(batch_size * num_beams, 2, -1)
100
+ attention_mask_shape = attention_mask.shape[-3:]
101
+ attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
102
+ batch_size * num_beams, *attention_mask_shape)
103
+ if strategy.is_done:
104
+ break
105
+ return strategy.finalize(tokens, mems)
106
+
107
+
108
+ def stream_generation_sequence(
109
+ model,
110
+ seqs,
111
+ strategy,
112
+ max_memory_length=100000,
113
+ get_masks_and_position_ids=get_masks_and_position_ids,
114
+ stream=False,
115
+ mems=None,
116
+ **kw_args
117
+ ):
118
+ '''
119
+ seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
120
+ mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
121
+ cache, should be first mems.shape[1] parts of context_tokens.
122
+ mems are the first-level citizens here, but we don't assume what is memorized.
123
+ input mems are used when multi-phase generation.
124
+ '''
125
+ assert len(seqs.shape) == 2
126
+ # building the initial tokens, attention_mask, and position_ids
127
+ batch_size, context_length = seqs.shape
128
+ seqs, attention_mask, position_ids = get_masks_and_position_ids(seqs)
129
+ tokens = seqs[..., :context_length]
130
+ # initialize generation
131
+ counter = context_length # Last fixed index is ``counter''
132
+ index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
133
+ num_beams = 1
134
+ # step-by-step generation
135
+ while counter < seqs.shape[1] - 1:
136
+ # Now, we want to generate seq[counter + 1],
137
+ # token[:, index: counter+1] needs forwarding.
138
+ # forward
139
+ tokens = tokens.reshape(batch_size * num_beams, -1)
140
+ mems = mems.reshape(mems.shape[0], batch_size * num_beams, mems.shape[-2], mems.shape[-1]) if mems is not None else None
141
+ model.eval()
142
+ with torch.no_grad():
143
+ logits, *output_per_layers = model(
144
+ tokens[:, index:],
145
+ position_ids[..., index: counter],
146
+ attention_mask[..., index: counter, :counter], # TODO memlen
147
+ mems=mems,
148
+ **kw_args
149
+ )
150
+ mem_kv = [o['mem_kv'] for o in output_per_layers]
151
+ mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
152
+ logits = logits[:, -1]
153
+ index = counter
154
+ counter += 1
155
+ logits = logits.reshape(batch_size, num_beams, -1)
156
+ tokens = tokens.reshape(batch_size, num_beams, -1)
157
+ mems = mems.reshape(mems.shape[0], batch_size, num_beams, mems.shape[-2], mems.shape[-1])
158
+ tokens, mems = strategy.forward(logits, tokens, mems, is_first=False)
159
+ if len(tokens.shape) == 3 and num_beams == 1:
160
+ num_beams = tokens.shape[1]
161
+ position_ids = position_ids.unsqueeze(1).expand(batch_size, num_beams, 2, -1).reshape(batch_size * num_beams, 2, -1)
162
+ attention_mask_shape = attention_mask.shape[-3:]
163
+ attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_beams, -1, -1, -1).reshape(
164
+ batch_size * num_beams, *attention_mask_shape)
165
+ yield tokens, mems
166
+ if strategy.is_done:
167
+ break
168
+
169
+
170
+
171
+ def autoregressive_sampling(args, raw_text: str, model, tokenizer, strategy, stream=False) -> Tuple[List[str], List[str], List[List[str]]]:
172
+ # add MASK
173
+ generation_mask = "[gMASK]"
174
+ seq = []
175
+ msa_len = len(raw_text[0]) + 1
176
+ seq += [tokenizer.get_command(generation_mask)] + [tokenizer.get_command("sop")]
177
+ for each in raw_text:
178
+ seq += tokenizer.tokenize(each) + [tokenizer.get_command('<M>')]
179
+
180
+ output_list = [seq]
181
+ num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
182
+ seq = output_list[0]
183
+ # detect mask position
184
+ mask_token = tokenizer.get_command(generation_mask)
185
+ mask_position = seq.index(mask_token)
186
+
187
+ last_pos, answers, blanks, output_list = (
188
+ [0] * num_output,
189
+ ["" for _ in range(num_output)],
190
+ [[] for _ in range(num_output)],
191
+ []
192
+ )
193
+ icl_msas = len(raw_text)
194
+ input_seq = torch.tensor(
195
+ [seq],
196
+ dtype = torch.long,
197
+ device=args.device,
198
+ )
199
+ if args.stream_chat:
200
+ if args.chinese:
201
+ print(f"{'生成的MSA'.center(20, '*')}", flush=True)
202
+ else:
203
+ print(f"{'Virtual MSA'.center(20, '*')}", flush=True)
204
+ output_stream = stream_generation_sequence(
205
+ model = model,
206
+ seqs = input_seq,
207
+ strategy=strategy,
208
+ get_masks_and_position_ids=partial(
209
+ get_masks_and_position_ids,
210
+ msa_len = msa_len,
211
+ max_gen_length=args.max_gen_length,
212
+ gmask=True
213
+ )
214
+ )
215
+ offset = -1
216
+ for tmp_res, mems in output_stream:
217
+ if isinstance(tmp_res, torch.Tensor):
218
+ output = tmp_res.tolist()
219
+ output_list = output[0]
220
+ for i in range(len(output_list)):
221
+ output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
222
+ bog = output.index(tokenizer.get_command("sop"))
223
+ try:
224
+ unfinished = output.index(-1)
225
+ except ValueError:
226
+ unfinished = len(output)
227
+ output_list[i] = output[:mask_position] + output[bog + 1 : unfinished]
228
+ for i, output in enumerate(output_list):
229
+ if output[-1] == tokenizer.get_command("eos"):
230
+ output = output[:-1]
231
+ answers[i] = tokenizer.detokenize(output)
232
+ tmp_ret = answers[0] # only support streaming output first line.
233
+ if mpu.get_model_parallel_rank() == 0:
234
+ if not args.multiline_stream:
235
+ vit_msa = tmp_ret[offset if offset>0 else -1:]
236
+ print(vit_msa, end='', flush=True)
237
+ offset = len(tmp_ret)
238
+ else:
239
+ print_len = 0
240
+ vit_msa = tmp_ret.split('[<M>]')[icl_msas:]
241
+ vit_msa = [_ for _ in vit_msa if len(_) > 0]
242
+ for _ in vit_msa:
243
+ print(_)
244
+ print_len += 1
245
+ move_cursor_up(print_len)
246
+
247
+ move_cursor_down(print_len)
248
+ print('\n')
249
+ output = strategy.finalize(tmp_res, mems)[0]
250
+ else:
251
+ output, _ = generation_sequence(
252
+ model = model,
253
+ seqs = input_seq,
254
+ strategy=strategy,
255
+ get_masks_and_position_ids=partial(
256
+ get_masks_and_position_ids,
257
+ msa_len = msa_len,
258
+ max_gen_length=args.max_gen_length,
259
+ gmask=True
260
+ )
261
+ )
262
+ last_pos, answers, blanks, output_list = (
263
+ [0] * num_output,
264
+ ["" for _ in range(num_output)],
265
+ [[] for _ in range(num_output)],
266
+ []
267
+ )
268
+ if isinstance(output, torch.Tensor): # different strategies
269
+ output = output.tolist()
270
+ output = output[0] # batch_size = 1
271
+ output_list.extend(output)
272
+ # clip -1s and fill back generated things into seq
273
+ for i in range(len(output_list)):
274
+ output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
275
+ try:
276
+ unfinished = output.index(-1)
277
+ except ValueError:
278
+ unfinished = len(output)
279
+ # if output[unfinished - 1] in strategy.end_tokens:
280
+ # unfinished -= 1
281
+ bog = output.index(tokenizer.get_command("sop"))
282
+
283
+ prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
284
+ blank = tokenizer.detokenize(output[bog + 1 : unfinished])
285
+ blanks[i].append(blank)
286
+ last_pos[i] = mask_position + unfinished - (bog + 1)
287
+ output_list[i] = output[:mask_position] + output[bog + 1 : unfinished]
288
+
289
+
290
+ for i, output in enumerate(output_list):
291
+ if output[-1] == tokenizer.get_command("eos"):
292
+ output = output[:-1]
293
+ answers[i] = tokenizer.detokenize(output)
294
+ return answers
295
+
296
+
297
+ def offline_generation(args, temp, top_p, top_k, func):
298
+ os.makedirs(args.output_path, exist_ok=True)
299
+ with open(args.input_source, 'r', encoding="utf-8") as fin:
300
+ inputs = fin.readlines()
301
+ output_path = os.path.join(args.output_path, f"tmp_{temp}_p_{top_p}_k_{top_k}")
302
+ fin = open(output_path, 'w')
303
+ start_time = time.time()
304
+ for line_no, raw_text in enumerate(inputs):
305
+ if line_no % mpu.get_data_parallel_world_size() != mpu.get_data_parallel_rank():
306
+ continue
307
+ rk = dist.get_rank()
308
+ raw_text = raw_text.strip()
309
+ raw_text = raw_text.split('<M>')
310
+ main_seq = raw_text[0]
311
+
312
+ msa_len = len(main_seq) + 1
313
+ icl_msas = len(raw_text)
314
+ require_min_gen_length = msa_len * (icl_msas + 1) + 2
315
+ if args.max_gen_length < require_min_gen_length:
316
+ args.max_gen_length = require_min_gen_length # at least generate 1 msa.
317
+
318
+ if mpu.get_model_parallel_rank() == 0:
319
+ print(f'Processing No. {line_no} on model group {rk} input main seq: "{main_seq}" few-shot prompt: "{"<M>".join(raw_text[1:])}"')
320
+ if len(raw_text) == 0:
321
+ continue
322
+ ret = func(raw_text)
323
+ if mpu.get_model_parallel_rank() == 0:
324
+ if args.print_all_beams:
325
+ for idx, vit_msa in enumerate(ret):
326
+ vit_msa = vit_msa.split('[<M>]')[icl_msas:]
327
+ vit_msa = [_ for _ in vit_msa if len(_) > 0]
328
+ vit_msa_len = len(vit_msa)
329
+ vit_msa_str = '<M>'.join(vit_msa)
330
+ print('Beam: {} #Vitural Length:{} | MSA: "{}" | (Temp, P, K)=({}, {}, {}) | Taken time {:.2f}'.format(idx, vit_msa_len, vit_msa_str, temp, top_p, top_k, time.time() - start_time), flush=True)
331
+ else:
332
+ vit_msa = ret[0]
333
+ vit_msa = vit_msa.split('[<M>]')[icl_msas:]
334
+ vit_msa = [_ for _ in vit_msa if len(_) > 0]
335
+ vit_msa_len = len(vit_msa)
336
+ vit_msa_str = '<M>'.join(vit_msa)
337
+ fin.write(f"{vit_msa_str}"+'\n')
338
+ print('#Vitural Length:{} | MSA: "{}" | (Temp, P, K)=({}, {}, {}) | Taken time {:.2f}'.format(vit_msa_len, vit_msa_str, temp, top_p, top_k, time.time() - start_time), flush=True)
339
+ print()
340
+ fin.flush()
341
+ dist.barrier()
342
+ fin.close()
343
+
344
+
345
+ def online_generation(args, query, temp, top_p, top_k, func):
346
+ raw_text = query.strip()
347
+ raw_text = raw_text.split('<M>')
348
+ main_seq = raw_text[0]
349
+ msa_len = len(main_seq) + 1
350
+ icl_msas = len(raw_text)
351
+ require_min_gen_length = msa_len * (icl_msas + 1) + 2
352
+ if args.max_gen_length < require_min_gen_length:
353
+ args.max_gen_length = require_min_gen_length # at least generate 1 msa.
354
+ ret = func(raw_text)
355
+ response = []
356
+ if mpu.get_model_parallel_rank() == 0:
357
+ for idx, vit_msa in enumerate(ret):
358
+ vit_msa = vit_msa.split('[<M>]')[icl_msas:]
359
+ vit_msa = [_ for _ in vit_msa if len(_) > 0]
360
+ response.append(vit_msa)
361
+ return response
362
+
363
+
364
+ def chat_api(args, model, tokenizer, strategy, query=None): # TODO: Steam chat
365
+ if args.input_source == 'chat':
366
+ assert query is not None
367
+ ret = online_generation(args, query, temp=args.temperature, top_p = args.top_p, top_k = args.top_k, func = partial(autoregressive_sampling, args, model = model, tokenizer = tokenizer, strategy = strategy))
368
+ return ret
369
+ else:
370
+ assert not args.stream_chat, "Offline Generation don't support streaming output."
371
+ offline_generation(args, temp=args.temperature, top_p = args.top_p, top_k = args.top_k, func = partial(autoregressive_sampling, args, model = model, tokenizer = tokenizer, strategy = strategy))
utils/strategies.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from sat.generation.sampling_strategies.base_strategy import top_k_logits
5
+ from sat.mpu.initialize import get_model_parallel_world_size, get_model_parallel_src_rank, get_model_parallel_group
6
+
7
+ class AdvancedBaseStrategy:
8
+ def __init__(self, batch_size, invalid_slices=[], temperature=1., no_repeat_ngram_size = 0, top_k=200, eps=1e-4, top_p=0.0, min_gen_length=1, end_tokens=None):
9
+ self.batch_size = batch_size
10
+ self.invalid_slices = invalid_slices
11
+ self.temperature = temperature
12
+ self.topk = top_k
13
+ self.top_p = top_p
14
+ self.eps = eps
15
+ self.min_gen_length = min_gen_length
16
+ self.ngram=no_repeat_ngram_size
17
+ if end_tokens is None:
18
+ end_tokens = []
19
+ self.end_tokens = end_tokens
20
+ self.length_generated = 0
21
+ self.cached_beam_ngram_bans = [{} for _ in range(self.batch_size)]
22
+ self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
23
+ self._init_cache()
24
+
25
+ @property
26
+ def is_done(self) -> bool:
27
+ return self._is_done.all()
28
+
29
+ def _init_cache(self):
30
+ self.length_generated = 0
31
+ self.cached_beam_ngram_bans = [[{}] for _ in range(self.batch_size)]
32
+ self._is_done = np.zeros(self.batch_size, dtype=bool)
33
+
34
+
35
+ def forward(self, logits, tokens, mems, is_first = False, temperature=None):
36
+ # print(is_first)
37
+ batch_size, num_beam, seq_len = tokens.shape
38
+ seq_len = tokens.shape[-1]
39
+ if temperature is None:
40
+ temperature = self.temperature
41
+ logits = logits / temperature
42
+ if self.min_gen_length > self.length_generated:
43
+ for end_token in self.end_tokens:
44
+ logits[..., end_token] = -65504
45
+ for invalid_slice in self.invalid_slices:
46
+ logits[..., invalid_slice] = -65504
47
+ if self.ngram > 0 and seq_len > self.ngram:
48
+ for batch_idx in range(batch_size):
49
+ for i in range(num_beam):
50
+ ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist() # TODO ngram=1
51
+ for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
52
+ logits[batch_idx, i, banned_index] = -65504
53
+ logits = logits.view(-1, logits.size(-1))
54
+ logits = top_k_logits(logits, self.topk, self.top_p)
55
+ probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
56
+
57
+ pred = torch.multinomial(probs, num_samples=1)
58
+ for i in range(self.batch_size):
59
+ if i >= batch_size:
60
+ self._is_done[i] = True
61
+ elif self._is_done[i]:
62
+ pred[i] = -1
63
+ elif pred[i].item() in self.end_tokens:
64
+ self._is_done[i] = True
65
+
66
+ if self.ngram > 0:
67
+ for batch_idx in range(batch_size):
68
+ bans_continue = []
69
+ for i in range(num_beam):
70
+ bans = self.cached_beam_ngram_bans[batch_idx][i].copy()
71
+ ngram_prefix = tuple(tokens[batch_idx, i, -(self.ngram - 1):].tolist())
72
+ bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (pred[batch_idx],)
73
+ bans_continue.append(bans)
74
+ self.cached_beam_ngram_bans[batch_idx] = bans_continue
75
+ tokens = torch.cat((tokens, pred.view(tokens.shape[:-1] + (1,))), dim=-1)
76
+ self.length_generated += 1
77
+
78
+ return tokens, mems
79
+
80
+ def finalize(self, tokens, mems):
81
+ self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
82
+ self._init_cache()
83
+ return tokens, mems
84
+
85
+
86
+ class BeamSearchStrategy:
87
+ def __init__(
88
+ self,
89
+ batch_size,
90
+ num_beams,
91
+ length_penalty=1.0,
92
+ consider_end=False,
93
+ end_tokens=[],
94
+ invalid_slices=[],
95
+ no_repeat_ngram_size=0,
96
+ min_gen_length=0,
97
+ deterministic=False,
98
+ ):
99
+ self.batch_size = batch_size
100
+ self.num_beams = num_beams
101
+ self.length_penalty = length_penalty
102
+ self.end_tokens = end_tokens
103
+ self.ngram = no_repeat_ngram_size
104
+ self.min_gen_length = min_gen_length
105
+ self.invalid_slices = invalid_slices
106
+ self.consider_end = consider_end
107
+ self.deterministic = deterministic
108
+ self._init_cache()
109
+
110
+ def _init_cache(self):
111
+ self.end_beams = [[] for _ in range(self.batch_size)] # list of LongTensors
112
+ self.end_beams_penalized_scores = [[] for _ in range(self.batch_size)] # list of LongTensors
113
+ self.cached_beam_scores = 0 # [batch_size]
114
+ self.cached_beam_ngram_bans = [[{} for _ in range(self.num_beams)] for _ in range(self.batch_size)]
115
+ self.length_generated = 0
116
+ self._is_done = np.zeros(self.batch_size, dtype=np.bool_)
117
+
118
+ def _add_end_beams(self, score, beam, batch_idx):
119
+ score = score / ((5.0 + len(beam)) / 6) ** self.length_penalty # Magic number for OpenNMT
120
+ for i in range(len(self.end_beams[batch_idx]), -1, -1):
121
+ if i == 0 or score < self.end_beams_penalized_scores[batch_idx][i - 1]:
122
+ break
123
+ self.end_beams[batch_idx].insert(i, beam)
124
+ self.end_beams_penalized_scores[batch_idx].insert(i, score)
125
+
126
+ self.end_beams[batch_idx] = self.end_beams[batch_idx][: self.num_beams]
127
+ self.end_beams_penalized_scores[batch_idx] = self.end_beams_penalized_scores[batch_idx][: self.num_beams]
128
+
129
+ @property
130
+ def is_done(self) -> bool:
131
+ return self._is_done.all()
132
+
133
+ def forward(self, logits, tokens, mems):
134
+ batch_size, num_beams, vocab_size = logits.shape
135
+ seq_len = tokens.shape[-1]
136
+ logits = logits.float()
137
+ for invalid_slice in self.invalid_slices:
138
+ logits[..., invalid_slice] = -65504
139
+ if self.min_gen_length > self.length_generated:
140
+ for end_token in self.end_tokens:
141
+ logits[..., end_token] = -65504
142
+ if self.ngram > 0 and seq_len > self.ngram:
143
+ for batch_idx in range(batch_size):
144
+ for i in range(num_beams):
145
+ ngram_prefix = tokens[batch_idx, i, -(self.ngram - 1) :].tolist() # TODO ngram=1
146
+ for banned_index in self.cached_beam_ngram_bans[batch_idx][i].get(tuple(ngram_prefix), []):
147
+ logits[batch_idx, i, banned_index] = -65504
148
+
149
+ next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
150
+ prev_scores = self.cached_beam_scores
151
+ if isinstance(prev_scores, torch.Tensor):
152
+ prev_scores = prev_scores[..., None].expand_as(next_token_scores)
153
+ next_token_scores = next_token_scores + prev_scores
154
+
155
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
156
+
157
+ probs = F.softmax(next_token_scores, dim=-1)
158
+ if num_beams < self.num_beams: # First token
159
+ probs = probs[..., :vocab_size]
160
+ if self.deterministic:
161
+ next_tokens = torch.topk(probs, k=(max(1, len(self.end_tokens)) + 1) * self.num_beams).indices # [2*nb]
162
+ else:
163
+ next_tokens = torch.multinomial(
164
+ probs, num_samples=(max(1, len(self.end_tokens)) + 1) * self.num_beams
165
+ ) # [2*nb]
166
+ next_token_scores = next_token_scores[torch.arange(batch_size).unsqueeze(1), next_tokens]
167
+ next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
168
+ next_tokens = next_tokens[torch.arange(batch_size).unsqueeze(1), _indices]
169
+
170
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="trunc")
171
+ next_tokens = next_tokens % vocab_size
172
+
173
+ # select out end beams or continue beams
174
+ beam_continue_batch, score_continue_batch, mems_continue_batch = [], [], []
175
+ for batch_idx in range(batch_size):
176
+ beam_continue = []
177
+ scores_continue = []
178
+ bans_continue = []
179
+ mems_contiue = []
180
+ for i in range(len(next_tokens[batch_idx])):
181
+ beam = torch.cat((tokens[batch_idx, next_indices[batch_idx, i]], next_tokens[batch_idx, i : i + 1]))
182
+ if not self._is_done[batch_idx] and int(next_tokens[batch_idx, i]) in self.end_tokens:
183
+ self._add_end_beams(next_token_scores[batch_idx, i], beam, batch_idx)
184
+ elif len(beam_continue) < self.num_beams:
185
+ beam_continue.append(beam)
186
+ mems_contiue.append(mems[:, batch_idx, next_indices[batch_idx, i]])
187
+ # update caches
188
+ scores_continue.append(next_token_scores[batch_idx, i])
189
+ if self.ngram > 0:
190
+ bans = self.cached_beam_ngram_bans[batch_idx][next_indices[batch_idx, i]].copy()
191
+ # TODO ngram=1
192
+ ngram_prefix = tuple(tokens[batch_idx, next_indices[batch_idx, i], -(self.ngram - 1):].tolist())
193
+ bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[batch_idx, i],)
194
+ bans_continue.append(bans)
195
+ else:
196
+ break
197
+ beam_continue_batch.append(torch.stack(beam_continue))
198
+ mems_continue_batch.append(torch.stack(mems_contiue, dim=1))
199
+ score_continue_batch.append(scores_continue)
200
+ self.cached_beam_ngram_bans[batch_idx] = bans_continue
201
+ tokens = torch.stack(beam_continue_batch)
202
+ mems = torch.stack(mems_continue_batch, dim=1)
203
+ self.cached_beam_scores = torch.tensor(score_continue_batch, device=logits.device)
204
+ self.length_generated += 1
205
+ for batch_idx in range(self.batch_size):
206
+ if batch_idx >= batch_size:
207
+ self._is_done[batch_idx] = True
208
+ elif (
209
+ len(self.end_beams[batch_idx]) == self.num_beams
210
+ and self.end_beams_penalized_scores[batch_idx][-1]
211
+ >= self.cached_beam_scores[batch_idx].max() / ((5.0 + (seq_len + 1)) / 6) ** self.length_penalty
212
+ ): # We're done if none of current tokens will better than the worst in end_beams
213
+ self._is_done[batch_idx] = True
214
+
215
+ return tokens, mems
216
+
217
+ def finalize(self, tokens, mems):
218
+ if self.consider_end:
219
+ batch_size, num_beams = tokens.shape[:2]
220
+ for batch_idx in range(batch_size):
221
+ if not self._is_done[batch_idx]:
222
+ for i in range(num_beams):
223
+ self._add_end_beams(self.cached_beam_scores[batch_idx, i], tokens[batch_idx, i], batch_idx)
224
+ mems = None
225
+ ret = self.end_beams[:batch_size]
226
+ else:
227
+ ret = tokens
228
+ self._init_cache()
229
+ return ret, mems
utils/tokenization.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Sequence, Tuple, List, Union
2
+ import itertools
3
+
4
+ class ResidueLevelTokenizer:
5
+ """
6
+ Tokenizer for Protein Residue Level Tokenization.
7
+ """
8
+
9
+ def __init__(self, **kwargs):
10
+ super(ResidueLevelTokenizer, self).__init__()
11
+ self.pad_tok = ['[pad]']
12
+ self.all_toks = self.pad_tok
13
+ self._tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
14
+ self.all_toks.extend(self._tokens)
15
+ self._special_tokens = ['MASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '</s>', '<M>']
16
+ self.set_special_tokens(self._special_tokens)
17
+ self.special_tokens['eos']=self.special_tokens['</s>']
18
+ self.special_tokens['tMASK']=self.special_tokens['MASK']
19
+
20
+ self.all_toks.extend(self._special_tokens)
21
+ self._vocab = {t: i for i, t in enumerate(self.all_toks)}
22
+ self.command_token = {'[tMASK]': 'tMASK', '[MASK]':'MASK', '[gMASK]': 'gMASK', '[sMASK]':'sMASK'}
23
+ # print('Building vocab.: {}'.format(self._vocab))
24
+ # print('Special_tokens: {}'.format(self.special_tokens))
25
+ # print('All tokens: {}'.format(self.all_toks))
26
+
27
+ def pad_id(self):
28
+ return self._vocab['[pad]']
29
+
30
+ def set_special_tokens(self, special_tokens):
31
+ """Add a list of additional tokens to the encoder.
32
+ The additional tokens are indexed starting from the last index of the
33
+ current vocabulary in the order of the `special_tokens` list.
34
+ """
35
+ if not special_tokens:
36
+ self.special_tokens = {}
37
+ self.special_tokens_decoder = {}
38
+ return
39
+ self.special_tokens = dict((tok, len(self.all_toks) + i) for i, tok in enumerate(special_tokens))
40
+ self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
41
+
42
+
43
+ def __len__(self):
44
+ return len(self._vocab)
45
+
46
+
47
+ def EncodeAsIds(self, text, process_fn=None):
48
+ """convert sequence to idx"""
49
+ processed_text = text
50
+ if process_fn is not None:
51
+ processed_text = process_fn(processed_text)
52
+ processed_text = str(processed_text)
53
+ tokens = [self.TokenToId(c) for c in processed_text]
54
+ return tokens
55
+
56
+ def IdToToken(self, idx):
57
+ if idx == 0:
58
+ return '[pad]'
59
+ elif idx in self.special_tokens_decoder:
60
+ return f"[{self.special_tokens_decoder[idx]}]"
61
+ else:
62
+ try:
63
+ tok = self.all_toks[idx]
64
+ except:
65
+ tok = '*'
66
+ return tok
67
+ def TokenToId(self, token):
68
+ if token == '[pad]':
69
+ return 0
70
+ elif token in self.special_tokens:
71
+ return self.special_tokens[token]
72
+ else:
73
+ return self._vocab[token]
74
+
75
+ def DecodeIds(self, Ids):
76
+ return ''.join([self.IdToToken(tok) for tok in Ids])
77
+
78
+ def _tokenize(self, text) -> str:
79
+ return text.split()
80
+
81
+ def tokenize(self, text, **kwargs) -> List[str]:
82
+ """
83
+ Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
84
+ Converts a string in a sequence of tokens, using the tokenizer.
85
+
86
+ Args:
87
+ text (:obj:`str`):
88
+ The sequence to be encoded.
89
+
90
+ Returns:
91
+ :obj:`List[str]`: The list of tokens.
92
+ """
93
+
94
+ def split_on_token(tok, text):
95
+ result = []
96
+ split_text = text.split(tok)
97
+ for i, sub_text in enumerate(split_text):
98
+ # AddedToken can control whitespace stripping around them.
99
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
100
+ # Cf. https://github.com/huggingface/transformers/pull/2778
101
+ # and https://github.com/huggingface/transformers/issues/3788
102
+ # We strip left and right by default
103
+ if i < len(split_text) - 1:
104
+ sub_text = sub_text.rstrip()
105
+ if i > 0:
106
+ sub_text = sub_text.lstrip()
107
+
108
+ if i == 0 and not sub_text:
109
+ result.append(tok)
110
+ elif i == len(split_text) - 1:
111
+ if sub_text:
112
+ result.append(sub_text)
113
+ else:
114
+ pass
115
+ else:
116
+ if sub_text:
117
+ result.append(sub_text)
118
+ result.append(tok)
119
+ return result
120
+
121
+ def split_on_tokens(tok_list, text):
122
+ if not text.strip():
123
+ return []
124
+
125
+ tokenized_text = []
126
+ text_list = [text]
127
+ for tok in tok_list:
128
+ tokenized_text = []
129
+ for sub_text in text_list:
130
+ if sub_text not in self._tokens:
131
+ tokenized_text.extend(split_on_token(tok, sub_text))
132
+ else:
133
+ tokenized_text.append(sub_text)
134
+ text_list = tokenized_text
135
+
136
+ return list(
137
+ itertools.chain.from_iterable(
138
+ (
139
+ self._tokenize(token)
140
+ if token not in self.all_toks
141
+ else [token]
142
+ for token in tokenized_text
143
+ )
144
+ )
145
+ )
146
+ no_split_token = self.all_toks
147
+ tokenized_text = split_on_tokens(no_split_token, text)
148
+ return self.convert_tokens_to_ids(tokenized_text)
149
+
150
+ def convert_tokens_to_ids(self, tokens):
151
+ """Converts a sequence of tokens into ids using the vocab."""
152
+ ids = []
153
+ # print_rank_0(tokens)
154
+ # print_rank_0(self.vocab)
155
+ for token in tokens:
156
+ ids.append(self.TokenToId(token))
157
+ return ids
158
+
159
+
160
+ class proteinglm_tokenizer:
161
+ """
162
+ Protein Tokenizer based on Residue level tokenizer
163
+ """
164
+
165
+ def __init__(self):
166
+ name = 'ProteinTokenizer'
167
+ self.tokenizer = ResidueLevelTokenizer()
168
+ self.special_tokens = self.tokenizer.special_tokens
169
+
170
+
171
+ def IdToToken(self, idx):
172
+ return self.tokenizer.IdToToken(idx)
173
+
174
+ def TokenToId(self, token):
175
+ return self.tokenizer.TokenToId(token)
176
+
177
+ @property
178
+ def vocab_size(self):
179
+ return len(self.tokenizer)
180
+
181
+ def decode(self, token_ids):
182
+ return self.tokenizer.DecodeIds([token_ids])
183
+
184
+ @property
185
+ def eod(self):
186
+ return self.tokenizer.get_special_token('eos')
187
+
188
+ def detokenize(self, Ids, type_token=False):
189
+ new_tokens = self.tokenizer.DecodeIds(Ids)
190
+ return new_tokens
191
+
192
+ def tokenize(self, text):
193
+ ids = self.tokenizer.tokenize(text)
194
+ return ids
195
+
196
+ @property
197
+ def vocab(self):
198
+ return self.tokenizer._vocab
199
+
200
+ @property
201
+ def inv_vocab(self):
202
+ return {v:k for k, v in self.tokenizer._vocab.items()}
203
+
204
+ @property
205
+ def get_pad_id(self):
206
+ return self.tokenizer.pad_id
207
+
208
+
209
+ def get_command(self, token):
210
+ tok = token
211
+ if token in self.tokenizer.command_token:
212
+ tok = self.tokenizer.command_token[token]
213
+ return self.tokenizer.special_tokens[tok]
utils/utils.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ def move_cursor_up(n):
2
+ # ANSI escape code to move cursor up by n lines
3
+ print(f"\033[{n}A", end='')
4
+
5
+ def move_cursor_down(n):
6
+ # ANSI escape code to move cursor down by n lines
7
+ print(f"\033[{n}B", end='')