nancyH commited on
Commit
ab6c03c
·
verified ·
1 Parent(s): b0c4b1b

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +17 -0
  3. LICENSE +201 -0
  4. README.md +380 -0
  5. SNP/SNP.py +85 -0
  6. SNP/example_mut_file.txt +6 -0
  7. SNP/examples/dev.tsv +6 -0
  8. SNP/mutate_seqs.py +118 -0
  9. examples/.Rhistory +0 -0
  10. examples/.run_pretrain.py.swp +0 -0
  11. examples/6mer_pretrain_emb/static_6mer_embeddings.npy +3 -0
  12. examples/6mer_pretrain_emb_20ways/static_6mer_embed_20ways.npy +3 -0
  13. examples/6mer_pretrain_emb_adaptive/static_adaptive_embed.npy +3 -0
  14. examples/compute_result.py +290 -0
  15. examples/data_process_template/.process_pretrain_data_multi.py.swp +0 -0
  16. examples/data_process_template/process_690.py +103 -0
  17. examples/data_process_template/process_csv.py +311 -0
  18. examples/data_process_template/process_finetune_data.py +713 -0
  19. examples/data_process_template/process_ner.py +132 -0
  20. examples/data_process_template/process_pretrain_data.py +148 -0
  21. examples/data_process_template/process_pretrain_data_multi.py +63 -0
  22. examples/data_process_template/process_scan_prom_data.py +76 -0
  23. examples/gen_cCRE_emb_final.py +113 -0
  24. examples/load_model_test.py +69 -0
  25. examples/requirements.txt +11 -0
  26. examples/run_finetune.py +1284 -0
  27. examples/run_pretrain.py +885 -0
  28. examples/run_pretrain.sh.save +36 -0
  29. examples/sample_data/ft/6/dev.tsv +0 -0
  30. examples/sample_data/ft/6/train.tsv +3 -0
  31. examples/sample_data/pre/6_3k.txt +0 -0
  32. examples/save_static_embeddings.py +65 -0
  33. examples/scripts/run_mut.sh +45 -0
  34. examples/scripts/uce.sh +26 -0
  35. examples/visualize.py +152 -0
  36. motif/find_motifs.py +112 -0
  37. motif/motif_utils.py +553 -0
  38. save2cache.py +224 -0
  39. setup.cfg +36 -0
  40. setup.py +127 -0
  41. src/transformers/__init__.py +436 -0
  42. src/transformers/activations.py +48 -0
  43. src/transformers/commands/__init__.py +13 -0
  44. src/transformers/commands/convert.py +144 -0
  45. src/transformers/commands/download.py +32 -0
  46. src/transformers/commands/env.py +58 -0
  47. src/transformers/commands/run.py +96 -0
  48. src/transformers/commands/serving.py +214 -0
  49. src/transformers/commands/train.py +144 -0
  50. src/transformers/commands/user.py +209 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/sample_data/ft/6/train.tsv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ cache*
3
+ dna_cache*
4
+ examples/runs
5
+ examples/ft
6
+ examples/output*
7
+ examples/ft_new
8
+ examples/results
9
+ examples/data_old
10
+ examples/data
11
+ examples/result
12
+ examples/models
13
+ src/transformers/data/__pycache__
14
+ src/transformers/data/metrics/__pycache__
15
+ src/transformers/data/processors/__pycache__
16
+ src/transformers/__pycache__
17
+ src/transformers.egg-info
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DNABERT
2
+ This repository includes the implementation of 'DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome'. Please cite our paper if you use the models or codes. The repo is still actively under development, so please kindly report if there is any issue encountered.
3
+
4
+ In this package, we provide resources including: source codes of the DNABERT model, usage examples, pre-trained models, fine-tuned models and a visualization tool. This package is still under development, as more features will be included gradually. Training of DNABERT consists of general-purpose pre-training and task-specific fine-tuning. As a contribution of our project, we released the pre-trained models in this repository. We extended codes from [huggingface](https://github.com/huggingface/transformers) and adapted them to the DNA scenario.
5
+
6
+ ## Update 2025/07/08
7
+
8
+ The original links to the pretrained DNABERT models (DNABERT-3, 4, 5, 6) have expired. Please go to HuggingFace to access and download the models:
9
+
10
+ DNABERT-3: https://huggingface.co/zhihan1996/DNA_bert_3
11
+ DNABERT-4: https://huggingface.co/zhihan1996/DNA_bert_4
12
+ DNABERT-5: https://huggingface.co/zhihan1996/DNA_bert_5
13
+ DNABERT-6: https://huggingface.co/zhihan1996/DNA_bert_6
14
+
15
+ ## Update 2023/06/26
16
+
17
+ The second generation of DNABERT, named [DNABERT-2](https://arxiv.org/abs/2306.15006), is publicly available at https://github.com/Zhihan1996/DNABERT_2. DNABERT-2 is trained on multi-species genomes and is more efficient, powerful, and easy to use than its first generation. We also provide simpler usage of DNABERT in the new package. A comprehensive benchmark Genome Understanding Evaluation (GUE), which contains $28$ datasets on $7$ tasks, is also published. Please check out DNABERT-2 if you are interested in our work. Thanks!
18
+
19
+
20
+ ## Citation
21
+ If you have used DNABERT in your research, please kindly cite the following publications:
22
+
23
+ ```
24
+ @article{ji2021dnabert,
25
+ author = {Ji, Yanrong and Zhou, Zhihan and Liu, Han and Davuluri, Ramana V},
26
+ title = "{DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome}",
27
+ journal = {Bioinformatics},
28
+ volume = {37},
29
+ number = {15},
30
+ pages = {2112-2120},
31
+ year = {2021},
32
+ month = {02},
33
+ issn = {1367-4803},
34
+ doi = {10.1093/bioinformatics/btab083},
35
+ url = {https://doi.org/10.1093/bioinformatics/btab083},
36
+ eprint = {https://academic.oup.com/bioinformatics/article-pdf/37/15/2112/50578892/btab083.pdf},
37
+ }
38
+
39
+
40
+ @misc{zhou2023dnabert2,
41
+ title={DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome},
42
+ author={Zhihan Zhou and Yanrong Ji and Weijian Li and Pratik Dutta and Ramana Davuluri and Han Liu},
43
+ year={2023},
44
+ eprint={2306.15006},
45
+ archivePrefix={arXiv},
46
+ primaryClass={q-bio.GN}
47
+ }
48
+ ```
49
+
50
+
51
+ ## 1. Environment setup
52
+
53
+ We recommend you to build a python virtual environment with [Anaconda](https://docs.anaconda.com/anaconda/install/linux/). Also, please make sure you have at least one NVIDIA GPU with Linux x86_64 Driver Version >= 410.48 (compatible with CUDA 10.0). We applied distributed training on 8 NVIDIA GeForce RTX 2080 Ti with 11 GB graphic memory, and the batch size corresponds to it. If you use GPU with other specifications and memory sizes, consider adjusting your batch size accordingly.
54
+
55
+ #### 1.1 Create and activate a new virtual environment
56
+
57
+ ```
58
+ conda create -n dnabert python=3.6
59
+ conda activate dnabert
60
+ ```
61
+
62
+
63
+
64
+ #### 1.2 Install the package and other requirements
65
+
66
+ (Required)
67
+
68
+ ```
69
+ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
70
+
71
+ git clone https://github.com/jerryji1993/DNABERT
72
+ cd DNABERT
73
+ python3 -m pip install --editable .
74
+ cd examples
75
+ python3 -m pip install -r requirements.txt
76
+ ```
77
+
78
+
79
+
80
+ (Optional, install apex for fp16 training)
81
+
82
+ change to a desired directory by `cd PATH_NAME`
83
+
84
+ ```
85
+ git clone https://github.com/NVIDIA/apex
86
+ cd apex
87
+ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
88
+ ```
89
+
90
+
91
+
92
+
93
+
94
+ ## 2. Pre-train (Skip this section if you fine-tune on pre-trained models)
95
+
96
+ #### 2.1 Data processing
97
+
98
+ Please see the template data at `/example/sample_data/pre`. If you are trying to pre-train DNABERT with your own data, please process your data into the same format as it. Note that the sequences are in kmer format, so you will need to convert your sequences into that. We also provide a custom function `seq2kmer` in `motif/motif_utils.py` for this conversion.
99
+
100
+
101
+
102
+ In the following example, we use DNABERT with kmer=6 as example.
103
+
104
+
105
+
106
+ #### 2.2 Model Training
107
+
108
+ ```
109
+ cd examples
110
+
111
+ export KMER=6
112
+ export TRAIN_FILE=sample_data/pre/6_3k.txt
113
+ export TEST_FILE=sample_data/pre/6_3k.txt
114
+ export SOURCE=PATH_TO_DNABERT_REPO
115
+ export OUTPUT_PATH=output$KMER
116
+
117
+ python run_pretrain.py \
118
+ --output_dir $OUTPUT_PATH \
119
+ --model_type=dna \
120
+ --tokenizer_name=dna$KMER \
121
+ --config_name=$SOURCE/src/transformers/dnabert-config/bert-config-$KMER/config.json \
122
+ --do_train \
123
+ --train_data_file=$TRAIN_FILE \
124
+ --do_eval \
125
+ --eval_data_file=$TEST_FILE \
126
+ --mlm \
127
+ --gradient_accumulation_steps 25 \
128
+ --per_gpu_train_batch_size 10 \
129
+ --per_gpu_eval_batch_size 6 \
130
+ --save_steps 500 \
131
+ --save_total_limit 20 \
132
+ --max_steps 200000 \
133
+ --evaluate_during_training \
134
+ --logging_steps 500 \
135
+ --line_by_line \
136
+ --learning_rate 4e-4 \
137
+ --block_size 512 \
138
+ --adam_epsilon 1e-6 \
139
+ --weight_decay 0.01 \
140
+ --beta1 0.9 \
141
+ --beta2 0.98 \
142
+ --mlm_probability 0.025 \
143
+ --warmup_steps 10000 \
144
+ --overwrite_output_dir \
145
+ --n_process 24
146
+ ```
147
+
148
+ Add the --fp16 tag if you want to perform mixed-precision training. (You have to install 'apex' from source first).
149
+
150
+
151
+
152
+
153
+
154
+ ## 3. Fine-tune (Skip this section if you use fine-tuned model)
155
+
156
+ #### 3.1 Data processing
157
+
158
+ Please see the template data at `/example/sample_data/ft/`. If you are trying to fine-tune DNABERT with your own data, please process your data into the same format as it. Note that the sequences are in kmer format, so you will need to convert your sequences into that. We also provide a custom function `seq2kmer` in `motif/motif_utils.py` for this conversion.
159
+
160
+
161
+
162
+ #### 3.2 Download pre-trained DNABERT
163
+
164
+ [DNABERT3](https://drive.google.com/file/d/1nVBaIoiJpnwQxiz4dSq6Sv9kBKfXhZuM/view?usp=sharing)
165
+
166
+ [DNABERT4](https://drive.google.com/file/d/1V7CChcC6KgdJ7Gwdyn73OS6dZR_J-Lrs/view?usp=sharing)
167
+
168
+ [DNABERT5](https://drive.google.com/file/d/1KMqgXYCzrrYD1qxdyNWnmUYPtrhQqRBM/view?usp=sharing)
169
+
170
+ [DNABERT6](https://drive.google.com/file/d/1BJjqb5Dl2lNMg2warsFQ0-Xvn1xxfFXC/view?usp=sharing)
171
+
172
+ Download the pre-trained model in to a directory. (If you would like to replicate the following examples, please download DNABERT 6). Then unzip the package by running:
173
+
174
+ ```
175
+ unzip 6-new-12w-0.zip
176
+ ```
177
+
178
+ We also provide a model with `KMER=6` that is fine-tuned on the sample dataset for prediction/visualization/motif_analysis. If you use the fine-tuned model instead of fine-tuning a model by yourself, please download the fine-tuned model and put it under `examples/ft/6`.
179
+
180
+ [Fine-tuned Model](https://drive.google.com/drive/folders/15wFcukTv3ecPw9_25dcOv-bZmj-8d_-6?usp=sharing)
181
+
182
+
183
+ #### 3.3 Fine-tune with pre-trained model
184
+
185
+ In the following example, we use DNABERT with kmer=6 as example. We use `prom-core`, a 2-class classification task as example.
186
+
187
+ ```
188
+ cd examples
189
+
190
+ export KMER=6
191
+ export MODEL_PATH=PATH_TO_THE_PRETRAINED_MODEL
192
+ export DATA_PATH=sample_data/ft/$KMER
193
+ export OUTPUT_PATH=./ft/$KMER
194
+
195
+ python run_finetune.py \
196
+ --model_type dna \
197
+ --tokenizer_name=dna$KMER \
198
+ --model_name_or_path $MODEL_PATH \
199
+ --task_name dnaprom \
200
+ --do_train \
201
+ --do_eval \
202
+ --data_dir $DATA_PATH \
203
+ --max_seq_length 100 \
204
+ --per_gpu_eval_batch_size=32 \
205
+ --per_gpu_train_batch_size=32 \
206
+ --learning_rate 2e-4 \
207
+ --num_train_epochs 5.0 \
208
+ --output_dir $OUTPUT_PATH \
209
+ --evaluate_during_training \
210
+ --logging_steps 100 \
211
+ --save_steps 4000 \
212
+ --warmup_percent 0.1 \
213
+ --hidden_dropout_prob 0.1 \
214
+ --overwrite_output \
215
+ --weight_decay 0.01 \
216
+ --n_process 8
217
+ ```
218
+
219
+ Add the --fp16 tag if you want to perform mixed-precision training. (You have to install 'apex' from source first).
220
+
221
+ We also provide a model with `KMER=6` that is fine-tuned on the sample dataset for prediction/visualization/motif_analysis. If you use the fine-tuned model instead of fine-tuning a model by yourself, please download the fine-tuned model and put it under `examples/ft/6`.
222
+
223
+ [Fine-tuned Model](https://drive.google.com/drive/folders/15wFcukTv3ecPw9_25dcOv-bZmj-8d_-6?usp=sharing)
224
+
225
+
226
+
227
+ ## 4. Prediction
228
+
229
+ After the model is fine-tuned, we can get predictions by running
230
+
231
+ ```$
232
+ export KMER=6
233
+ export MODEL_PATH=./ft/$KMER
234
+ export DATA_PATH=sample_data/ft/$KMER
235
+ export PREDICTION_PATH=./result/$KMER
236
+
237
+ python run_finetune.py \
238
+ --model_type dna \
239
+ --tokenizer_name=dna$KMER \
240
+ --model_name_or_path $MODEL_PATH \
241
+ --task_name dnaprom \
242
+ --do_predict \
243
+ --data_dir $DATA_PATH \
244
+ --max_seq_length 75 \
245
+ --per_gpu_pred_batch_size=128 \
246
+ --output_dir $MODEL_PATH \
247
+ --predict_dir $PREDICTION_PATH \
248
+ --n_process 48
249
+ ```
250
+
251
+ With the above command, the fine-tuned DNABERT model will be loaded from `MODEL_PATH` , and makes prediction on the `dev.tsv` file that saved in `DATA_PATH` and save the prediction result at `PREDICTION_PATH`.
252
+
253
+
254
+ Add the --fp16 tag if you want to perform mixed-precision training. (You have to install 'apex' from source first).
255
+
256
+
257
+ ## 5. Visualization
258
+
259
+ Visualization of DNABERT consists of 2 steps: calculating attention scores and plotting.
260
+
261
+ #### 5.1 Calculate attention scores
262
+
263
+ calculate with only one model (For example, DNABERT6)
264
+
265
+ ```
266
+ export KMER=6
267
+ export MODEL_PATH=./ft/$KMER
268
+ export DATA_PATH=sample_data/ft/$KMER
269
+ export PREDICTION_PATH=./result/$KMER
270
+
271
+ python run_finetune.py \
272
+ --model_type dna \
273
+ --tokenizer_name=dna$KMER \
274
+ --model_name_or_path $MODEL_PATH \
275
+ --task_name dnaprom \
276
+ --do_visualize \
277
+ --visualize_data_dir $DATA_PATH \
278
+ --visualize_models $KMER \
279
+ --data_dir $DATA_PATH \
280
+ --max_seq_length 81 \
281
+ --per_gpu_pred_batch_size=16 \
282
+ --output_dir $MODEL_PATH \
283
+ --predict_dir $PREDICTION_PATH \
284
+ --n_process 96
285
+ ```
286
+
287
+ With the above command, the fine-tuned DNABERT model will be loaded from `MODEL_PATH` , and calculates attention scores on the `dev.tsv` file that saved in `DATA_PATH` and save the result at `PREDICTION_PATH`.
288
+
289
+ Add the --fp16 tag if you want to perform mixed-precision training. (You have to install 'apex' from source first).
290
+
291
+ #### 5.2 Plotting tool
292
+
293
+ ## 6. Motif analysis
294
+
295
+ Once the attention scores are generated, we can proceed further to perform motif analysis using `motif/find_motifs.py`:
296
+
297
+ ```
298
+ cd ../motif
299
+
300
+ export KMER=6
301
+ export DATA_PATH=../examples/sample_data/ft/$KMER
302
+ export PREDICTION_PATH=../examples/result/$KMER
303
+ export MOTIF_PATH=./result/$KMER
304
+
305
+ python find_motifs.py \
306
+ --data_dir $DATA_PATH \
307
+ --predict_dir $PREDICTION_PATH \
308
+ --window_size 24 \
309
+ --min_len 5 \
310
+ --pval_cutoff 0.005 \
311
+ --min_n_motif 3 \
312
+ --align_all_ties \
313
+ --save_file_dir $MOTIF_PATH \
314
+ --verbose
315
+ ```
316
+
317
+ The script will generate a .txt file and a weblogo .png file for each motif under `MOTIF_PATH`.
318
+
319
+ ## 7. Genomic variants analysis
320
+
321
+ To perform genomic variants analysis (e.g. SNPs), we need to first ensure the predictions for the sequences were generated. Then, create a file (template in `SNP/example_mut_file.txt`) specifying for which sequences in `dev.tsv` and start and end indices where we need to perform the mutation. The first column indicates the index of sequence in `dev.tsv` to be mutated. Second and third columns are the start and end indices while the fourth column is the target of mutation (can be substitution, insertion, deletion, etc.)
322
+
323
+ Once such a file is created, we can perform mutation on the sequences:
324
+
325
+ ```
326
+ cd ../SNP
327
+ python mutate_seqs.py ./../examples/sample_data/ft/6/dev.tsv ./examples/ --mut_file ./example_mut_file.txt --k 6
328
+ ```
329
+ Alternatively, we can choose to leave the `--mut_file` argument blank, where the program would try to perform substitution of all bases to the four possible nucleotides ('A', 'T', 'C', or 'G') for all sequences. This would be useful for plotting a mutation heatmap as included in the paper. **Note that this would be slow if the `dev.tsv` contains a lot of sequences or the input sequences are very long, as the command would try to perform mutation on all possible locations of them**.
330
+
331
+ ```
332
+ cd ../SNP
333
+ python mutate_seqs.py ./../examples/sample_data/ft/6/dev.tsv ./examples/ --k 6
334
+ ```
335
+
336
+ After that, we can again predict on the generated sequences. **Note: if you have insertion/deletions in your `mut_file.txt`, consider changing the `max_seq_length` we use when making predictions.**
337
+
338
+ ```
339
+ export KMER=6
340
+ export MODEL_PATH=../examples/ft/$KMER
341
+ export DATA_PATH=examples
342
+ export PREDICTION_PATH=examples
343
+
344
+ python ../examples/run_finetune.py \
345
+ --model_type dna \
346
+ --tokenizer_name=dna$KMER \
347
+ --model_name_or_path $MODEL_PATH \
348
+ --task_name dnaprom \
349
+ --do_predict \
350
+ --data_dir $DATA_PATH \
351
+ --max_seq_length 75 \
352
+ --per_gpu_pred_batch_size=128 \
353
+ --output_dir $MODEL_PATH \
354
+ --predict_dir $PREDICTION_PATH \
355
+ --n_process 48
356
+ ```
357
+
358
+ This will again create `pred_results.npy` file under the `$PREDICTION_PATH`. Once we have all the above, we can compute the effect of these mutations by:
359
+
360
+ ```
361
+ python SNP.py \
362
+ --orig_seq_file ../examples/sample_data/ft/6/dev.tsv \
363
+ --orig_pred_file ../examples/result/6/pred_results.npy \
364
+ --mut_seq_file examples/dev.tsv \
365
+ --mut_pred_file examples/pred_results.npy \
366
+ --save_file_dir examples
367
+ ```
368
+
369
+ This would save a `mutations.tsv` file under `save_file_dir`, that contains index of original sequence (in original `dev.tsv`), original sequence and predictions, mutated sequence and predictions, as well as the difference score and log odds ratio of the change in every case.
370
+
371
+
372
+ ## Q&A
373
+
374
+ #### 1. I cannot start training the model/I have installation issues for the dependencies.
375
+
376
+ Please kindly make sure that you satisfied all system requirements for DNABERT, and that you have a conda environment properly set up. We have recently successfully tested our pipeline on Amazon EC2 Deep Learning AMI (Ubuntu 18.04). As an option, you could compare your system/environment setup with this AMI.
377
+
378
+ #### 2. Can DNABERT run on sequences longer than 512?
379
+
380
+ #### 3. Can DNABERT be extended to multi-class classification?
SNP/SNP.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### ::: DNABERT-viz SNP analysis ::: ####
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append('../motif')
6
+ import pandas as pd
7
+ import numpy as np
8
+ import argparse
9
+ import motif_utils as utils
10
+
11
+
12
def main():
    """Score the effect of mutations by comparing model predictions on the
    original sequences against predictions on their mutated counterparts.

    Reads the original and mutated sequence .tsv files (k-mer tokenized, as
    produced by the fine-tuning pipeline and mutate_seqs.py) together with
    their prediction .npy arrays, joins mutants to originals via the shared
    row index, and writes a ``mutations.tsv`` report under --save_file_dir.
    """
    parser = argparse.ArgumentParser()
    # NOTE: the original arguments carried defaults alongside required=True;
    # argparse never uses a default on a required option, so they are dropped.
    parser.add_argument(
        "--orig_seq_file",
        type=str,
        required=True,
        help="Path to original input sequence+label .tsv file.",
    )
    parser.add_argument(
        "--orig_pred_file",
        type=str,
        required=True,
        help="Path to predictions (pred_results.npy) of original sequences.",
    )
    parser.add_argument(
        "--mut_seq_file",
        type=str,
        required=True,
        help="Path to mutated sequence+index .tsv file.",
    )
    parser.add_argument(
        "--mut_pred_file",
        type=str,
        required=True,
        help="Path to predictions (pred_results.npy) of mutated sequences.",
    )
    parser.add_argument(
        "--save_file_dir",
        default='.',
        type=str,
        help="Path to save outputs",
    )
    args = parser.parse_args()

    # original sequences: recover the raw sequence from the k-mer tokens and
    # use the row position as the join key
    orig_dev = pd.read_csv(args.orig_seq_file, sep='\t', header=0)
    orig_dev.columns = ['sequence', 'label']
    orig_dev['orig_seq'] = orig_dev['sequence'].apply(utils.kmer2seq)
    orig_dev['idx'] = orig_dev.index
    orig_dev['orig_pred'] = np.load(args.orig_pred_file)

    # mutated sequences: the third column links each mutant back to its
    # original row; the label column is a placeholder and is ignored
    mut_dev = pd.read_csv(args.mut_seq_file, sep='\t', header=0)
    mut_dev.columns = ['sequence', 'label', 'idx']
    mut_dev['mut_seq'] = mut_dev['sequence'].apply(utils.kmer2seq)
    mut_dev['mut_pred'] = np.load(args.mut_pred_file)

    # merge: one output row per (original, mutant) pair
    dev = pd.merge(orig_dev[['idx', 'orig_seq', 'orig_pred']],
                   mut_dev[['idx', 'mut_seq', 'mut_pred']],
                   on='idx'
                   )
    # difference score, weighted by the larger of the two predictions so that
    # changes involving a confident prediction score higher
    dev['diff'] = (dev['mut_pred'] - dev['orig_pred']) * (dev[['orig_pred', 'mut_pred']].max(axis=1))
    # log odds ratio of original vs mutated prediction (positive when the
    # mutation lowers the predicted probability)
    dev['logOR'] = np.log2(dev['orig_pred'] / (1 - dev['orig_pred'])) - np.log2(dev['mut_pred'] / (1 - dev['mut_pred']))

    # BUGFIX: create the output directory instead of crashing when missing
    os.makedirs(args.save_file_dir, exist_ok=True)
    dev.to_csv(os.path.join(args.save_file_dir, 'mutations.tsv'), sep='\t')


if __name__ == "__main__":
    main()
SNP/example_mut_file.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ 0 30 31 G
2
+ 23 52 53 T
3
+ 104 14 15 C
4
+ 125 22 23 A
5
+ 240 8 8 A
6
+ 325 10 11
SNP/examples/dev.tsv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ TTTTTA TTTTAA TTTAAA TTAAAA TAAAAG AAAAGT AAAGTA AAGTAA AGTAAA GTAAAC TAAACA AAACAC AACACT ACACTG CACTGT ACTGTT CTGTTT TGTTTT GTTTTC TTTTCA TTTCAT TTCATT TCATTA CATTAG ATTAGG TTAGGG TAGGGC AGGGCC GGGCCA GGCCAA GCCAAG CCAAGC CAAGCT AAGCTA AGCTAA GCTAAT CTAATC TAATCC AATCCT ATCCTT TCCTTA CCTTAT CTTATT TTATTG TATTGA ATTGAG TTGAGA TGAGAA GAGAAT AGAATT GAATTT AATTTC ATTTCT TTTCTA TTCTAA TCTAAA CTAAAG TAAAGG AAAGGG AAGGGA AGGGAC GGGACA GGACAT GACATT ACATTA 0
2
+ CGCATT GCATTA CATTAA ATTAAT TTAATA TAATAG AATAGT ATAGTG TAGTGG AGTGGA GTGGAC TGGACT GGACTA GACTAG ACTAGG CTAGGG TAGGGG AGGGGC GGGGCA GGGCAG GGCAGG GCAGGG CAGGGC AGGGCT GGGCTG GGCTGG GCTGGA CTGGAT TGGATT GGATTT GATTTT ATTTTC TTTTCG TTTCGG TTCGGA TCGGAG CGGAGG GGAGGC GAGGCA AGGCAG GGCAGT GCAGTG CAGTGT AGTGTG GTGTGC TGTGCA GTGCAG TGCAGT GCAGTT CAGTTC AGTTCC GTTCCC TTCCCA TCCCAA CCCAAT CCAATA CAATAA AATAAC ATAACT TAACTA AACTAG ACTAGT CTAGTT TAGTTC AGTTCC 23
3
+ TTCATA TCATAA CATAAA ATAAAT TAAATT AAATTA AATTAC ATTACC TTACCC TACCCC ACCCCG CCCCGT CCCGTT CCGTTT CGTTTC GTTTCT TTTCTC TTCTCA TCTCAT CTCATA TCATAG CATAGT ATAGTT TAGTTC AGTTCT GTTCTT TTCTTT TCTTTA CTTTAT TTTATA TTATAG TATAGC ATAGCA TAGCAG AGCAGT GCAGTG CAGTGT AGTGTG GTGTGA TGTGAA GTGAAA TGAAAA GAAAAC AAAACA AAACAG AACAGA ACAGAC CAGACT AGACTA GACTAA ACTAAT CTAATG TAATGG AATGGA ATGGAC TGGACC GGACCC GACCCT ACCCTT CCCTTC CCTTCT CTTCTG TTCTGG TCTGGT CTGGTT 104
4
+ GAGATA AGATAA GATAAA ATAAAG TAAAGG AAAGGA AAGGAA AGGAAG GGAAGG GAAGGG AAGGGA AGGGAA GGGAAT GGAATC GAATCA AATCAG ATCAGT TCAGTA CAGTAC AGTACC GTACCA TACCAT ACCATC CCATCC CATCCA ATCCAG TCCAGA CCAGAA CAGAAG AGAAGC GAAGCA AAGCAA AGCAAT GCAATG CAATGA AATGAG ATGAGA TGAGAT GAGATG AGATGG GATGGA ATGGAG TGGAGG GGAGGG GAGGGC AGGGCA GGGCAG GGCAGC GCAGCA CAGCAG AGCAGG GCAGGG CAGGGA AGGGAG GGGAGG GGAGGA GAGGAG AGGAGA GGAGAG GAGAGA AGAGAA GAGAAA AGAAAG GAAAGA AAAGAC 125
5
+ GGTACA GTACAA TACAAA ACAAAA CAAAAG AAAAGA AAAGAC AAGACG AGACGA GACGAA ACGAAC CGAACA GAACAA AACAAC ACAACG CAACGC AACGCC ACGCCA CGCCAT GCCATC CCATCC CATCCC ATCCCC TCCCCG CCCCGT CCCGTC CCGTCG CGTCGT GTCGTC TCGTCG CGTCGA GTCGAA TCGAAT CGAATG GAATGG AATGGC ATGGCA TGGCAG GGCAGA GCAGAC CAGACA AGACAA GACAAG ACAAGT CAAGTA AAGTAA AGTAAC GTAACC TAACCA AACCAG ACCAGT CCAGTC CAGTCT AGTCTT GTCTTT TCTTTG CTTTGT TTTGTA TTGTAA TGTAAC GTAACG TAACGT AACGTA ACGTAG CGTAGT GTAGTG 240
6
+ GGAACT GAACTT AACTTA ACTTAA CTTAAA TTAAAn TAAAna AAAnan AAnanG AnanGG nanGGC anGGCC nGGCCG GGCCGG GCCGGC CCGGCT CGGCTG GGCTGT GCTGTT CTGTTT TGTTTC GTTTCG TTTCGG TTCGGC TCGGCG CGGCGG GGCGGC GCGGCC CGGCCG GGCCGC GCCGCG CCGCGG CGCGGG GCGGGA CGGGAT GGGATG GGATGC GATGCC ATGCCC TGCCCC GCCCCT CCCCTG CCCTGC CCTGCG CTGCGC TGCGCT GCGCTG CGCTGA GCTGAC CTGACC TGACCG GACCGC ACCGCC CCGCCA CGCCAG GCCAGG CCAGGG CAGGGG AGGGGC GGGGCA GGGCAG GGCAGG GCAGGT CAGGTG AGGTGC GGTGCC GTGCCC 325
SNP/mutate_seqs.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### ::: mutate seqs ::: ####
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append('../motif')
6
+ import pandas as pd
7
+ import numpy as np
8
+ import argparse
9
+ import motif_utils as utils
10
+
11
+
12
def mutate(seq, start, end, target=None):
    """
    Mutate the input sequence at the specified position.

    If ``target`` is not None, the substring ``seq[start:end]`` is replaced
    by ``target`` (an empty target performs a deletion, a longer target an
    insertion). Otherwise all four single-nucleotide substitutions at the
    site are returned.

    Arguments:
        seq -- str, original sequence.
        start -- int, starting index (0-based, inclusive) of the region to replace.
        end -- int, ending index (0-based, exclusive) of the region to replace.

    Keyword arguments:
        target -- str, the target nucleotide(s) to be changed to (default: None).

    Returns:
        mutated_seq -- str if target is given; otherwise a numpy array of
            shape (4,) with the A/T/G/C substitutions, in that order.
            (BUGFIX: the docstring previously claimed shape (4,1).)
    """
    assert 0 <= start <= end <= len(seq), "Wrong start and end index input."

    if target is not None:
        return seq[:start] + str(target) + seq[end:]

    variants = [seq[:start] + base + seq[end:] for base in ('A', 'T', 'G', 'C')]
    return np.asarray(variants)
42
+
43
def main():
    """Generate mutated versions of input sequences and save them as dev.tsv.

    With --mut_file: applies exactly the mutations listed there (one row per
    mutation: original-row index, start, end, replacement allele; an empty
    allele means deletion). Without it: exhaustively substitutes every
    position of every sequence with all four nucleotides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "seq_file",
        type=str,
        help="Path to input sequence+label .tsv file.",
    )

    parser.add_argument(
        "save_file_dir",
        type=str,
        help="Path to save the mutated seqs",
    )

    parser.add_argument(
        "--mut_file",
        default=None,
        type=str,
        help="Path to the file defining how each input seq should be mutated",
    )

    parser.add_argument(
        "--k",
        default=3,
        type=int,
        help="length of kmer for conversion of mutated seqs"
    )

    # TODO: add the conditions
    args = parser.parse_args()

    os.makedirs(args.save_file_dir, exist_ok=True)

    mutated_dev = {'index':[],'seq':[]}

    # input sequences are k-mer tokenized; recover the raw sequences first
    dev = pd.read_csv(args.seq_file,sep='\t',header=0)
    dev.columns = ['sequence','label']
    dev['seq'] = dev['sequence'].apply(utils.kmer2seq)

    if args.mut_file is not None:
        # targeted mutations: headerless tsv of (idx, start, end, allele);
        # NaN alleles (deletions) become empty strings via fillna
        mut_file = pd.read_csv(args.mut_file, sep='\t',header=None)
        mut_file = mut_file.fillna('')
        mut_file.columns = ['idx','start', 'end', 'allele']
        mut_file['idx'] = mut_file['idx'].astype(int)
        mut_file['start'] = mut_file['start'].astype(int)
        mut_file['end'] = mut_file['end'].astype(int)
        # row i of mut_file applies to row i of this selection
        dev_selected = dev.iloc[mut_file['idx'].tolist(),:].reset_index()
        for i, row in dev_selected.iterrows():
            seq = row['seq']
            mut = mut_file.iloc[i]
            mut_seq = mutate(seq, mut['start'], mut['end'], target = mut['allele'])
            mut_seq = utils.seq2kmer(mut_seq, args.k)
            mutated_dev['index'].append(mut['idx'])
            mutated_dev['seq'].append(mut_seq)
    else:
        # exhaustive scan: all 4 substitutions at every position (one of the
        # four equals the original sequence at each position)
        for i, row in dev.iterrows():
            seq = row['seq']
            for j in range(len(seq)):
                mut_seq = mutate(seq, j, j+1)
                mut_seq = [utils.seq2kmer(seq, args.k) for seq in mut_seq]
                idx = [i] * 4
                mutated_dev['index'].extend(idx)
                mutated_dev['seq'].extend(mut_seq)

    mutated_dev = pd.DataFrame.from_dict(mutated_dev)
    mutated_dev = mutated_dev[['seq','index']]
    mutated_dev.columns = ['sequence','index']
    # all labels are placeholders; the first row is set to 1 — presumably so
    # downstream loaders see both classes present (TODO confirm with caller)
    mutated_dev['label'] = 0
    mutated_dev.iloc[0, mutated_dev.columns.get_loc('label')] = 1
    mutated_dev = mutated_dev[['sequence','label','index']]

    mutated_dev.to_csv(os.path.join(args.save_file_dir,'dev.tsv'),sep='\t',header=True, index=False)


if __name__ == "__main__":
    main()
examples/.Rhistory ADDED
File without changes
examples/.run_pretrain.py.swp ADDED
Binary file (1.02 kB). View file
 
examples/6mer_pretrain_emb/static_6mer_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5422f25436f65a3cb50f5e3881ab1a4c0e3d417eb8fb11f485fc1f9b0ef0b04d
3
+ size 12598400
examples/6mer_pretrain_emb_20ways/static_6mer_embed_20ways.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e621f2367d58715c3defef6e0a504feed12e96a308da56f19383e68534e6b03
3
+ size 12598400
examples/6mer_pretrain_emb_adaptive/static_adaptive_embed.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41de47985ee1cd6d29a98951beece1d79d7c48e6295e7701e7bfb46f06079705
3
+ size 12598400
examples/compute_result.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import csv
4
+ from copy import deepcopy
5
+ from sklearn.metrics import matthews_corrcoef, confusion_matrix, f1_score
6
+
7
+ def generate_pred(predict_results, i, slide, metric="max"):
8
+
9
+ results = predict_results[i*3:(i+1)*3]
10
+
11
+ if metric == "max":
12
+ pred = max(results)
13
+ elif metric == "mean":
14
+ pred = np.mean(results)
15
+ elif metric == "second-max":
16
+ pred = np.sort(results)[-2]
17
+ else:
18
+ pass
19
+
20
+ return pred
21
+
22
def Compute_scan(args):
    """
    Evaluate scan-task predictions against labels and print summary metrics.

    Loads per-window scores from args.pred_path and binary labels from
    args.label_path, aggregates each example's windows with generate_pred,
    thresholds at args.bound, and prints counts, f1, mcc, accuracy and the
    confusion-matrix cells.
    """
    predict_results = np.load(args.pred_path)
    labels = list(np.load(args.label_path).astype(int))

    # threshold each example's aggregated score into a binary call
    results = [
        1 if generate_pred(predict_results, i, args.slide, args.metric) >= args.bound else 0
        for i in range(len(labels))
    ]

    # (removed two dead ``set(...)`` locals from the original)
    f1 = f1_score(y_true=labels, y_pred=results)
    mcc = matthews_corrcoef(labels, results)
    tn, fp, fn, tp = confusion_matrix(labels, results).ravel()

    count = sum(r == l for r, l in zip(results, labels))

    print("number of examples: " + str(len(labels)))
    print("number of positive examples: " + str(sum(labels)))
    print("number of negative examples: " + str(len(labels)-sum(labels)))
    print("f1: ", str(f1))
    print("mcc: " + str(mcc))
    print("accuracy: " + str(float(count)/len(results)))
    print("tn:" + str(tn))
    print("fp:" + str(fp))
    print("fn:" + str(fn))
    print("tp:" + str(tp))
+
57
+
58
def Compute_mouse(args):
    """
    Summarize mouse-task results: group runs by task, pick the best run per
    task, and print the macro-averaged metrics plus the worst tasks.

    args.pred_path is a whitespace-separated file where every line reads
    "<task> <acc> <auc> <aupr> <f1> <mcc> <precision> <recall>"; consecutive
    lines with the same task name are runs of the same task.
    """
    result_file = open(args.pred_path, "r")
    results = result_file.readlines()
    print(len(results))

    # group consecutive lines by task name (column 0)
    all_preds = []
    current_preds = []
    for result in results:
        scores = result.split()
        scores = [scores[0], float(scores[1]), float(scores[2]), float(scores[3]), float(scores[4]), float(scores[5]), float(scores[6]), float(scores[7])]
        if current_preds == [] or scores[0] == current_preds[0][0]:
            current_preds.append(scores)
        else:
            all_preds.append(current_preds)
            current_preds = []
            current_preds.append(scores)
    all_preds.append(current_preds)

    print("Number of task: %d" % len(all_preds))

    # sort-key helpers: accuracy / auc columns of a run
    def get_acc(val):
        return val[1]

    def get_auc(val):
        return val[2]

    tasks = []
    acc = []
    auc = []
    aupr = []
    f1 = []
    mcc = []
    precision = []
    recall = []

    for pred in all_preds:
        # flag tasks with suspiciously few runs
        if len(pred) < 10 :
            print("Short %s : %d" % (pred[0][0], len(pred)))

        if args.index == "acc":
            pred.sort(key=get_acc)
        elif args.index == "auc":
            pred.sort(key=get_auc)
        else:
            raise ValueError()

        # the best run defaults to the last one after sorting (highest key);
        # among runs tied with it on accuracy, the last run with strictly
        # higher auc wins instead
        BEST = -1
        for i in range(len(pred)):
            if pred[i][1] == pred[-1][1] and pred[i][2] > pred[-1][2]:
                BEST = deepcopy(i)
        tasks.append(pred[0][0])

        best_pred = pred[BEST]
        acc.append(best_pred[1])
        auc.append(best_pred[2])
        aupr.append(best_pred[3])
        f1.append(best_pred[4])
        mcc.append(best_pred[5])
        precision.append(best_pred[6])
        recall.append(best_pred[7])

    # macro-average the per-task best metrics
    acc_ave = np.mean(acc)
    auc_ave = np.mean(auc)
    aupr_ave = np.mean(aupr)
    f1_ave = np.mean(f1)
    mcc_ave = np.mean(mcc)
    precision_ave = np.mean(precision)
    recall_ave = np.mean(recall)


    print("acc: " + str(acc_ave))
    print("auc: " + str(auc_ave))
    print("aupr: " + str(aupr_ave))
    print("f1: ", str(f1_ave))
    print("mcc: " + str(mcc_ave))
    print("precision: ", str(precision_ave))
    print("recall: " + str(recall_ave))

    # find and print the tasks whose results are worst
    ranks = np.argsort(auc)[:args.num_worst]
    print("Top %d worst tasks: " % (args.num_worst))
    for i in ranks:
        print(tasks[i] + " %3f %3f" % (acc[i], auc[i]))
+ print(tasks[i] + " %3f %3f" % (acc[i], auc[i]))
141
+
142
+
143
+
144
+
145
def Compute_690(args):
    """
    Summarize 690-dataset results: every task contributes exactly
    args.num_results consecutive lines; pick the best run per task by
    args.index and print macro-averaged metrics and the worst tasks.

    Each line of args.pred_path is whitespace-separated:
    "<task> <acc> <auc> <col3> <f1> <mcc> ..." (column 3 is skipped below).
    """
    result_file = open(args.pred_path, "r")
    results = result_file.readlines()

    preds = []

    for result in results:
        scores = result.split()
        # keep task name, acc, auc, f1, mcc — column 3 is intentionally skipped
        preds.append([scores[0], float(scores[1]), float(scores[2]), float(scores[4]), float(scores[5])])

    num_results = args.num_results

    num_example = int(len(preds)/num_results)
    print("Num of tasks: %d" % num_example)

    # sort-key helpers for the different selection criteria
    def get_acc(val):
        return val[1]

    def get_auc(val):
        return val[2]

    def get_f1(val):
        return val[3]

    def get_mcc(val):
        return val[4]

    tasks = []
    acc = []
    auc = []
    f1 = []
    mcc = []

    for i in range(num_example):
        tasks.append(preds[i*num_results][0])

        # sort this task's runs by the chosen index; the last (highest) wins
        current_preds = preds[i*num_results:(i+1)*num_results]
        if args.index == "acc":
            current_preds.sort(key=get_acc)
        elif args.index == "auc":
            current_preds.sort(key=get_auc)
        elif args.index == "f1":
            current_preds.sort(key=get_f1)
        elif args.index == "mcc":
            current_preds.sort(key=get_mcc)
        else:
            raise ValueError()
        best_pred = current_preds[-1]
        acc.append(best_pred[1])
        auc.append(best_pred[2])
        f1.append(best_pred[3])
        mcc.append(best_pred[4])

    # calculate and print the average scores
    acc_ave = np.mean(acc)
    auc_ave = np.mean(auc)
    f1_ave = np.mean(f1)
    mcc_ave = np.mean(mcc)


    print("acc: " + str(acc_ave))
    print("auc: " + str(auc_ave))
    print("f1: ", str(f1_ave))
    print("mcc: " + str(mcc_ave))

    # find and print the tasks whose results are worst
    ranks = np.argsort(auc)[:args.num_worst]
    print("Top %d worst tasks: " % (args.num_worst))
    for i in ranks:
        print(tasks[i] + " %3f %3f" % (acc[i], auc[i]))
+ print(tasks[i] + " %3f %3f" % (acc[i], auc[i]))
215
+
216
+
217
+
218
def main():
    """CLI entry point: parse options and dispatch to the per-task summary."""
    parser = argparse.ArgumentParser()
    # BUGFIX: --bound's help was a copy-pasted "K-mer"
    parser.add_argument("--bound", default=0.5, type=float,
                        help="Decision threshold on the aggregated score (scan)")
    parser.add_argument("--pred_path", default=None, type=str,
                        help="The path of the predicted result")
    parser.add_argument("--label_path", default=None, type=str,
                        help="The path of the label")
    parser.add_argument("--metric", default="max", type=str,
                        help="The metric of computing predicted result (scan)")
    parser.add_argument("--slide", default=3, type=int,
                        help="Number of windows that make up each example (scan)")
    parser.add_argument("--task", default="scan", type=str,
                        help="Which task to compute result")
    parser.add_argument("--index", default="acc", type=str,
                        help="Which index to sort result (690)")
    # BUGFIX: the defaults were the strings "10"; use proper ints so the
    # value does not depend on argparse's string-default parsing
    parser.add_argument("--num_results", default=10, type=int,
                        help="Number of results for each task (690)")
    parser.add_argument("--num_worst", default=10, type=int,
                        help="Number of worst tasks to print out (690)")

    args = parser.parse_args()

    if args.task == "scan":
        Compute_scan(args)
    elif args.task == "690":
        Compute_690(args)
    elif args.task == "mouse":
        Compute_mouse(args)
    else:
        raise ValueError("Unknown task: %r" % (args.task,))


if __name__ == "__main__":
    main()
+ main()
examples/data_process_template/.process_pretrain_data_multi.py.swp ADDED
Binary file (4.1 kB). View file
 
examples/data_process_template/process_690.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import os
4
+ import numpy as np
5
+ import random
6
+ from process_pretrain_data import get_kmer_sentence
7
+
8
+
9
+
10
def Process(args):
    """
    Convert each dataset folder under args.file_path into shuffled k-mer
    train.tsv / dev.tsv files under args.output_path/<kmer>/<folder>/.

    Each folder is expected to contain train/sequences_alph.npy,
    train/targets.npy, test/sequences_alph.npy and test/targets.npy.
    """
    path = args.file_path
    all_folders = os.listdir(path)

    count = 0

    for folder in all_folders:
        # load data
        train_seq_path = os.path.join(args.file_path, folder, "train", "sequences_alph.npy")
        test_seq_path = os.path.join(args.file_path, folder, "test", "sequences_alph.npy")
        train_lab_path = os.path.join(args.file_path, folder, "train", "targets.npy")
        test_lab_path = os.path.join(args.file_path, folder, "test", "targets.npy")
        train_sequences = np.load(train_seq_path).reshape(-1, 1)
        test_sequences = np.load(test_seq_path).reshape(-1, 1)
        train_labels = np.load(train_lab_path).reshape(-1, 1)
        test_labels = np.load(test_lab_path).reshape(-1, 1)

        # concat sequence and labels together
        trains = list(np.concatenate((train_sequences, train_labels), axis=1))
        tests = list(np.concatenate((test_sequences, test_labels), axis=1))

        # deterministic shuffle; the double shuffle is kept on purpose to
        # reproduce the exact ordering of previously generated files
        random.seed(24)
        random.shuffle(trains)
        random.shuffle(trains)
        random.shuffle(tests)
        random.shuffle(tests)

        # make output path (exist_ok so reruns do not crash)
        output_path = os.path.join(args.output_path, str(args.kmer), folder)
        os.makedirs(output_path, exist_ok=True)

        # BUGFIX: the writer handles were never closed; use context managers
        # so files are flushed/closed even when an error occurs mid-folder
        with open(os.path.join(output_path, "train.tsv"), 'wt') as f_train:
            tsv_train = csv.writer(f_train, delimiter='\t')
            tsv_train.writerow(["sequence", "label"])
            for seq, label in trains:
                # sequences are stored as bytes in the .npy files
                tsv_train.writerow([get_kmer_sentence(seq.decode("utf-8"), args.kmer), int(label)])

        with open(os.path.join(output_path, "dev.tsv"), 'wt') as f_dev:
            tsv_dev = csv.writer(f_dev, delimiter='\t')
            tsv_dev.writerow(["sequence", "label"])
            for seq, label in tests:
                tsv_dev.writerow([get_kmer_sentence(seq.decode("utf-8"), args.kmer), int(label)])

        count += 1
        print("Finish %s folders" % (count))
68
+
69
+
70
+
71
+
72
+
73
+
74
def main():
    """Parse command-line options and run the dataset conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=1, type=int, help="K-mer")
    parser.add_argument("--file_path", default=None, type=str,
                        help="The path of the file to be processed")
    parser.add_argument("--output_path", default=None, type=str,
                        help="The path of the processed data")

    Process(parser.parse_args())


if __name__ == "__main__":
    main()
examples/data_process_template/process_csv.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import json
4
+ import argparse
5
+ import random
6
+ from process_pretrain_data import get_kmer_sentence
7
+
8
+
9
+ max_length = 0
10
+
11
def Process_pair(args):
    """
    Build train/dev/test files for the paired enhancer-promoter task.

    Reads <root>_enhancer.fasta / <root>_promoter.fasta / <root>_label.txt
    (and their *_test counterparts) from args.file_path, pairs the fasta
    sequence lines with labels, shuffles the training pairs, optionally
    carves out a 10% dev split (--dev), and writes k-mer sentences as
    csv/tsv (--csv selects the format).

    Fixes over the original: all file handles are closed via context
    managers, and creating the "test" subdirectory no longer crashes when
    it already exists.
    """
    random.seed(42)

    root_path = args.file_path.split('/')[-1]
    prefix = args.file_path + "/" + root_path

    def read_lines(path):
        # fasta files alternate header/sequence lines; callers index 2*i+1
        with open(path, "r") as fh:
            return fh.readlines()

    train_seq1 = read_lines(prefix + "_enhancer.fasta")
    train_seq2 = read_lines(prefix + "_promoter.fasta")
    train_label = read_lines(prefix + "_label.txt")
    test_seq1 = read_lines(prefix + "_enhancer_test.fasta")
    test_seq2 = read_lines(prefix + "_promoter_test.fasta")
    test_label = read_lines(prefix + "_label_test.txt")

    train_lines = [[train_seq1[2*i+1], train_seq2[2*i+1], train_label[i]]
                   for i in range(len(train_label))]
    test_lines = [[test_seq1[2*i+1], test_seq2[2*i+1], test_label[i]]
                  for i in range(len(test_label))]

    random.shuffle(train_lines)

    dev_lines = []
    if args.dev:
        num_dev = int(len(train_lines)/10)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    output_path = make_path(args)

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    def dump_pairs(path, lines):
        # write one complete output file: header then k-mer sentence pairs
        with open(path, 'wt') as fh:
            w = csv.writer(fh, delimiter=delimiter)
            w.writerow(["seq1", "seq2", "label"])
            for line in lines:
                seq1 = get_kmer_sentence(line[0], kmer=args.kmer, stride=args.stride)
                seq2 = get_kmer_sentence(line[1], kmer=args.kmer, stride=args.stride)
                w.writerow([seq1, seq2, str(int(line[2]))])

    dump_pairs(os.path.join(output_path, "train" + suffix), train_lines)
    if args.dev:
        dump_pairs(os.path.join(output_path, "dev" + suffix), dev_lines)
        os.makedirs(os.path.join(output_path, "test"), exist_ok=True)
        dump_pairs(os.path.join(output_path, "test", "dev" + suffix), test_lines)
    else:
        dump_pairs(os.path.join(output_path, "dev" + suffix), test_lines)
75
+
76
+
77
def make_path(args):
    """Resolve the output directory — the explicit --output_path if given,
    otherwise <file_path>/<kmer> — create it if missing, and return it."""
    if args.output_path:
        target = args.output_path
    else:
        target = os.path.join(args.file_path, str(args.kmer))
    if not os.path.exists(target):
        os.makedirs(target)
    return target
82
+
83
def write_file(lines, writer, seq_index=2, label_index=3, kmer=6, stride=1):
    """Write [k-mer sentence, label] rows to an open csv writer.

    Side effect: updates the module-level ``max_length`` with the longest
    sentence (in tokens) seen so far; Process() reports it at the end.
    A label_index of -100 means "no label available" and emits 0.
    """
    global max_length
    for line in lines:
        sentence = get_kmer_sentence(line[seq_index], kmer=kmer, stride=stride)
        if len(sentence.split()) > max_length:
            max_length = len(sentence.split())
        if label_index == -100:
            writer.writerow([sentence, str(0)])
        else:
            writer.writerow([sentence, str(line[label_index])])
93
+
94
def Process(args):
    """
    Convert train.csv / test.csv under args.file_path into shuffled k-mer
    train/dev(/test) files, splitting 8:1:1 with --dev, else 9:1.

    Fixes over the original: --kmer/--stride are now actually forwarded to
    write_file (the output was previously always built with the default
    kmer=6), the dev split uses --seq_index/--label_index like the other
    splits, and all file handles are closed via context managers.
    """
    random.seed(24)

    # utf-8-sig strips a potential BOM from exported csv files
    with open(os.path.join(args.file_path, "train.csv"), "r", encoding="utf-8-sig") as f:
        train_lines = list(csv.reader(f, delimiter=",", quotechar=None))[1:]
    with open(os.path.join(args.file_path, "test.csv"), "r", encoding="utf-8-sig") as f:
        test_lines = list(csv.reader(f, delimiter=",", quotechar=None))[1:]

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    dev_lines = []
    if args.dev:
        num_dev = int(len(train_lines)/9)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    print(train_lines[0])

    output_path = make_path(args)

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    def dump(name, lines):
        # write one complete output file: header row then data rows
        with open(os.path.join(output_path, name + suffix), 'wt') as fh:
            w = csv.writer(fh, delimiter=delimiter)
            w.writerow(["sentence", "label"])
            write_file(lines, w, args.seq_index, args.label_index,
                       kmer=args.kmer, stride=args.stride)

    dump("train", train_lines)
    if args.dev:
        dump("dev", dev_lines)
        dump("test", test_lines)
    else:
        dump("dev", test_lines)

    # max_length is accumulated by write_file across all splits
    print("max length: %d" % (max_length))
145
+
146
+
147
def Process_UCE(args):
    """
    Convert a UCE variant csv into a dev file of k-mer sentences.

    For every input row, the sequence in column 8 (presumably the reference
    allele — TODO confirm against the source csv) and the first mutated
    sequence (column -2) are written; a second mutated sequence (column -1)
    is written only when it differs from the first. Also saves
    line2index.json (input row number -> 1-based output row indices) and
    lencount.json (sequence length -> occurrence count).
    """
    len_count = {}          # sequence length -> occurrence count
    line2index = {}         # input row number -> list of output row indices

    pred_file = open(args.file_path, "r", encoding="utf-8-sig")
    pred_lines = list(csv.reader(pred_file, delimiter=",", quotechar=None))[1:]

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    f_pred = open(os.path.join(args.output_path, "dev"+suffix), 'wt')
    pred_w = csv.writer(f_pred, delimiter=delimiter)
    pred_w.writerow(["sentence", "label"])

    # output indices are 1-based because row 0 of dev is the header
    index = 1
    line_num = 0
    for line in pred_lines:
        len_count[len(line[8])] = len_count.get(len(line[8]), 0) + 1
        len_count[len(line[-2])] = len_count.get(len(line[-2]), 0) + 1

        cur_index = [index, index+1]
        ref = get_kmer_sentence(line[8], args.kmer, args.stride)
        pred_w.writerow([ref, 0])

        mut1 = get_kmer_sentence(line[-2], args.kmer, args.stride)
        pred_w.writerow([mut1, 0])

        index += 2

        # only write the second allele when it differs from the first
        if line[-2] != line[-1]:
            len_count[len(line[-1])] = len_count.get(len(line[-1]), 0) + 1
            mut2 = get_kmer_sentence(line[-1], args.kmer, args.stride)
            pred_w.writerow([mut2, 0])
            cur_index.append(index)
            index += 1

        line2index[line_num] = cur_index
        line_num += 1

    with open(os.path.join(args.output_path, "line2index.json"), "w") as f:
        json.dump(line2index, f)
    with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
        json.dump(len_count, f)
191
+
192
+
193
def Process_Virus(args):
    """
    Convert a directory of virus csv files into a dev file of k-mer
    sentences, mirroring Process_UCE.

    All csv files under args.file_path (except those starting with
    "unclass") are concatenated; for every row, the sequence in column 8
    and the mutated sequence in column -2 are written, plus column -1 when
    it differs. line2index.json and lencount.json are saved as in
    Process_UCE (lencount here only counts the second-allele lengths, as in
    the original loop body).

    BUGFIX: the original referenced undefined names (pred_lines, len_count,
    line2index) and ended in a body-less ``with`` statement — a syntax
    error that prevented the whole module from importing.
    """
    len_count = {}
    line2index = {}

    file_path = args.file_path

    all_files = os.listdir(file_path)
    all_files = [f for f in all_files if not f.startswith("unclass")]
    all_lines = []
    for f in all_files:
        f_dir = os.path.join(file_path, f)
        with open(f_dir, "r", encoding="utf-8-sig") as cur_file:
            cur_lines = list(csv.reader(cur_file, delimiter=",", quotechar=None))[1:]
        all_lines.extend(cur_lines)

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    with open(os.path.join(args.output_path, "dev"+suffix), 'wt') as f_pred:
        pred_w = csv.writer(f_pred, delimiter=delimiter)
        pred_w.writerow(["sentence", "label"])

        index = 1
        line_num = 0
        for line in all_lines:  # BUGFIX: was the undefined name pred_lines
            cur_index = [index, index+1]
            ref = get_kmer_sentence(line[8], args.kmer, args.stride)
            pred_w.writerow([ref, 0])

            mut1 = get_kmer_sentence(line[-2], args.kmer, args.stride)
            pred_w.writerow([mut1, 0])

            index += 2

            if line[-2] != line[-1]:
                len_count[len(line[-1])] = len_count.get(len(line[-1]), 0) + 1
                mut2 = get_kmer_sentence(line[-1], args.kmer, args.stride)
                pred_w.writerow([mut2, 0])
                cur_index.append(index)
                index += 1

            line2index[line_num] = cur_index
            line_num += 1

    with open(os.path.join(args.output_path, "line2index.json"), "w") as f:
        json.dump(line2index, f)
    # BUGFIX: the original ended with this ``with`` and no body
    with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
        json.dump(len_count, f)
236
+ json.dump(line2index, f)
237
+ with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
238
+
239
+
240
+
241
+
242
def main():
    """Parse CLI options and dispatch to the matching processing routine."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=1, type=int, help="K-mer")
    parser.add_argument("--stride", default=1, type=int,
                        help="stride in getting kmer sequence")
    parser.add_argument("--file_path", default=None, type=str,
                        help="The path of the file to be processed")
    parser.add_argument("--output_path", default=None, type=str,
                        help="The path of the processed data")
    parser.add_argument("--dev", action="store_true",
                        help="Use this flag to split data as (8:1:1), else (9:1)")
    parser.add_argument("--csv", action="store_true",
                        help="if output csv file or not, if not, output tsv")
    parser.add_argument("--pair", action="store_true",
                        help="Use this flag to split data as (8:1:1), else (9:1)")
    parser.add_argument("--uce", action="store_true",
                        help="Use this flag to split data as (8:1:1), else (9:1)")
    parser.add_argument("--seq_index", default=2, type=int,
                        help="index of seq in the original csv file")
    parser.add_argument("--label_index", default=3, type=int,
                        help="index of label in the original csv file")
    args = parser.parse_args()

    # dispatch: --pair beats --uce; plain csv processing is the fallback
    if args.pair:
        Process_pair(args)
    elif args.uce:
        Process_UCE(args)
    else:
        Process(args)


if __name__ == "__main__":
    main()
examples/data_process_template/process_finetune_data.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import os
4
+ import random
5
+ import numpy as np
6
+ from process_pretrain_data import get_kmer_sentence
7
+
8
+ max_length = 0
9
+
10
def write_file(lines, path, kmer, head=True, seq_index=0, label_index=1):
    """Write (sequence, label) rows as a tab-separated file.

    Args:
        lines: iterable of indexable rows from the source file.
        path: destination .tsv path (overwritten).
        kmer: k-mer size used to re-tokenize the sequence; 0 keeps the
            sequence text unchanged.
        head: when True, write the header row first.
        seq_index: column holding the raw sequence in each input row.
        label_index: column holding the label; None writes a constant "0".
    """
    with open(path, 'wt') as f:
        tsv_w = csv.writer(f, delimiter='\t')
        if head:
            # NOTE(review): "setence" is a typo, but downstream readers may
            # key on this exact header text, so it is kept unchanged.
            tsv_w.writerow(["setence", "label"])
        for line in lines:
            if kmer == 0:
                sentence = str(line[seq_index])
            else:
                # Drop any existing whitespace tokenization before re-cutting
                # the sequence into k-mers.
                sentence = str(get_kmer_sentence("".join(line[seq_index].split()), kmer))
            # Identity comparison with None (PEP 8) instead of ``== None``.
            label = "0" if label_index is None else str(line[label_index])
            tsv_w.writerow([sentence, label])
25
+
26
+
27
def Shuffle(args):
    """Shuffle the data rows of ``args.file_path`` in place.

    The first row is treated as a header and dropped before shuffling;
    ``write_file`` then rewrites the file (kmer=0, sequences untouched)
    with a fresh header row.  NOTE(review): no random seed is set here, so
    the result is non-deterministic, and the header is rewritten as
    ["setence", "label"] regardless of the original header text.
    """
    old_file = open(args.file_path, "r", encoding="utf-8-sig")
    old_lines = list(csv.reader(old_file, delimiter="\t", quotechar=None))[1:]
    random.shuffle(old_lines)

    write_file(old_lines, args.file_path, 0)
33
+
34
def Find_train(args):
    """Rebuild the promoter training split by excluding held-out test rows.

    Reads the full TATA/noTATA tables and the pre-defined test files from
    ``args.file_path``, keeps every [seq, label] pair that does not appear
    in the matching test file, writes the shuffled result to ``train.tsv``,
    then emits k-mer tokenized train/dev files for k in 3..6 (per-class dev
    files under ``tata/`` and ``notata/``).
    """
    random.seed(args.seed)

    # Full tables; the first row of each file is a header and is dropped.
    tata = args.file_path + "/TATA_249to50.tsv"
    notata = args.file_path + "/noTATA_249to50.tsv"
    tata_file = open(tata, "r", encoding="utf-8-sig")
    notata_file = open(notata, "r", encoding="utf-8-sig")
    tata_lines = list(csv.reader(tata_file, delimiter="\t", quotechar=None))[1:]
    notata_lines = list(csv.reader(notata_file, delimiter="\t", quotechar=None))[1:]

    # Pre-defined held-out test rows.
    tata_test = args.file_path + "/tata_test.tsv"
    notata_test = args.file_path + "/notata_test.tsv"
    tata_test_file = open(tata_test, "r", encoding="utf-8-sig")
    notata_test_file = open(notata_test, "r", encoding="utf-8-sig")
    tata_test_lines = list(csv.reader(tata_test_file, delimiter="\t", quotechar=None))[1:]
    notata_test_lines = list(csv.reader(notata_test_file, delimiter="\t", quotechar=None))[1:]


    train_lines = []

    # NOTE(review): membership test compares a 2-column pair against whole
    # test rows; assumes the test files have exactly two columns — confirm.
    # This is also O(n*m) list membership, acceptable for small files.
    for line in tata_lines:
        if [line[0], line[1]] not in tata_test_lines:
            train_lines.append([line[0], line[1]])


    for line in notata_lines:
        if [line[0], line[1]] not in notata_test_lines:
            train_lines.append([line[0], line[1]])

    random.shuffle(train_lines)
    random.shuffle(train_lines)

    # num_dev = int(len(train_lines)/9.0)
    # dev_lines = train_lines[:num_dev]
    # train_lines = train_lines[num_dev:]


    # Raw (un-tokenized) training file; re-read below for each k.
    write_file(train_lines, args.file_path+"/train.tsv", args.kmer, head=False)
    # write_file(dev_lines, args.file_path+"/dev.tsv", args.kmer)

    for kmer in range(3,7):
        root_path = os.path.join(args.file_path, str(kmer))
        if not os.path.exists(root_path):
            os.makedirs(root_path)

        train_file = open(os.path.join(args.file_path,"train.tsv"), "r", encoding="utf-8-sig")
        lines = list(csv.reader(train_file, delimiter="\t", quotechar=None))
        train_path = os.path.join(root_path,"train.tsv")

        write_file(lines, train_path, kmer)

        tata_path = os.path.join(root_path, "tata")
        notata_path = os.path.join(root_path, "notata")
        # Fails if the per-class folders already exist (no exist_ok).
        os.makedirs(tata_path)
        os.makedirs(notata_path)

        # Combined dev set plus per-class dev files.
        dev_lines = tata_test_lines+notata_test_lines
        dev_path = os.path.join(root_path,"dev.tsv")

        write_file(tata_test_lines, os.path.join(tata_path, "dev.tsv"), kmer)
        write_file(notata_test_lines, os.path.join(notata_path, "dev.tsv"), kmer)
        write_file(dev_lines, dev_path, kmer)
96
+
97
def Process_1000(args):
    """Balance the TATA/noTATA 1000bp scan sets and record noTATA test ids.

    Loads the four scan CSVs (header rows dropped), prints their sizes,
    shuffles each set, downsamples the noTATA sets to the TATA set sizes,
    and writes the metadata of the retained noTATA test rows to
    ``notata_test_id``.  The actual train/dev TSV emission is currently
    commented out.
    """
    random.seed(args.seed)

    # NOTE(review): no "/" separator here (unlike the test paths below) —
    # args.file_path is apparently expected to end with a slash; confirm.
    tata_train = args.file_path + "TATA_scan_train.csv"
    notata_train = args.file_path + "noTATA_scan_train.csv"
    tata_train_file = open(tata_train, "r", encoding="utf-8-sig")
    notata_train_file = open(notata_train, "r", encoding="utf-8-sig")
    tata_train_lines = list(csv.reader(tata_train_file, delimiter=",", quotechar=None))[1:]
    notata_train_lines = list(csv.reader(notata_train_file, delimiter=",", quotechar=None))[1:]

    tata_test = args.file_path + "/TATA_scan_test.csv"
    notata_test = args.file_path + "/noTATA_scan_test.csv"
    tata_test_file = open(tata_test, "r", encoding="utf-8-sig")
    notata_test_file = open(notata_test, "r", encoding="utf-8-sig")
    tata_test_lines = list(csv.reader(tata_test_file, delimiter=",", quotechar=None))[1:]
    notata_test_lines = list(csv.reader(notata_test_file, delimiter=",", quotechar=None))[1:]


    print("Original:")
    print("tata train: %d" % (len(tata_train_lines)))
    print("notata train: %d" % (len(notata_train_lines)))
    print("tata test: %d" % (len(tata_test_lines)))
    print("tata test: %d" % (len(notata_test_lines)))

    random.shuffle(tata_train_lines)
    random.shuffle(notata_train_lines)
    random.shuffle(tata_test_lines)
    random.shuffle(notata_test_lines)


    # Class balancing: keep as many noTATA rows as there are TATA rows.
    notata_train_lines = notata_train_lines[:len(tata_train_lines)]
    notata_test_lines = notata_test_lines[:len(tata_test_lines)]
    with open(os.path.join(args.file_path, "notata_test_id"), "w") as f:
        tsv_w = csv.writer(f, delimiter=',')
        tsv_w.writerow(["index", "chrom", "start", "end", "name", "strand", "keys", "id"])
        for line in notata_test_lines:
            # assumes columns 0-5 are BED-like metadata, 7 the keys and 9
            # the id — TODO confirm against the scan CSV schema.
            tsv_w.writerow([line[0], line[1], line[2], line[3], line[4], line[5], line[7], line[9]])



    # print("After:")
    # print("tata train: %d" % (len(tata_train_lines)))
    # print("notata train: %d" % (len(notata_train_lines)))
    # print("tata test: %d" % (len(tata_test_lines)))
    # print("tata test: %d" % (len(notata_test_lines)))

    # train_lines = tata_train_lines + notata_train_lines
    # test_lines = tata_test_lines + notata_test_lines


    # output_path = args.output_path if args.output_path is not None else args.file_path

    # write_file(test_lines, output_path+"/dev.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(train_lines, output_path+"/train.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(tata_test_lines, output_path+"/tata_dev.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(tata_train_lines, output_path+"/tata_train.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(notata_test_lines, output_path+"/notata_dev.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(notata_train_lines, output_path+"/notata_train.tsv", args.kmer, head=False, seq_index=8, label_index=6)

    # Process_1000_kmer(args, test_lines, train_lines, tata_test_lines, tata_train_lines, notata_test_lines, notata_train_lines)
157
+
158
+
159
def Process_1000_kmer(args, test_lines=None, train_lines=None, tata_test_lines=None, tata_train_lines=None, notata_test_lines=None, notata_train_lines=None):
    """Emit k-mer (3..6) train/dev TSVs for the 1000bp promoter task.

    When called without pre-loaded lines (``test_lines is None``) the six
    split TSVs are read back from ``args.file_path`` and ``LOAD`` flips to
    False, which switches the column layout from the raw CSV (seq at index
    8, label at 6) to the 2-column TSV layout (seq at 0, label at 1).
    The tata-specific output is currently commented out.
    """

    # LOAD stays True only when the caller passed pre-loaded raw CSV rows.
    LOAD = True
    output_path = args.output_path if args.output_path is not None else args.file_path

    if test_lines == None:
        path1 = os.path.join(args.file_path,"dev.tsv")
        path2 = os.path.join(args.file_path,"train.tsv")
        path3 = os.path.join(args.file_path,"tata_dev.tsv")
        path4 = os.path.join(args.file_path,"tata_train.tsv")
        path5 = os.path.join(args.file_path,"notata_dev.tsv")
        path6 = os.path.join(args.file_path,"notata_train.tsv")

        file1 = open(path1, "r", encoding="utf-8-sig")
        file2 = open(path2, "r", encoding="utf-8-sig")
        file3 = open(path3, "r", encoding="utf-8-sig")
        file4 = open(path4, "r", encoding="utf-8-sig")
        file5 = open(path5, "r", encoding="utf-8-sig")
        file6 = open(path6, "r", encoding="utf-8-sig")

        test_lines = list(csv.reader(file1, delimiter="\t", quotechar=None))
        train_lines = list(csv.reader(file2, delimiter="\t", quotechar=None))
        tata_test_lines = list(csv.reader(file3, delimiter="\t", quotechar=None))
        tata_train_lines = list(csv.reader(file4, delimiter="\t", quotechar=None))
        notata_test_lines = list(csv.reader(file5, delimiter="\t", quotechar=None))
        notata_train_lines = list(csv.reader(file6, delimiter="\t", quotechar=None))

        LOAD = False



    for kmer in range(3,7):

        print(kmer)
        root_path = os.path.join(output_path, str(kmer))
        if not os.path.exists(root_path):
            os.makedirs(root_path)

        all_path = os.path.join(root_path, "all")
        # tata_path = os.path.join(root_path, "tata")
        notata_path = os.path.join(root_path, "notata")
        # Fails if the subfolders already exist (no exist_ok).
        os.makedirs(all_path)
        # os.makedirs(tata_path)
        os.makedirs(notata_path)

        # Column layout depends on where the rows came from (see docstring).
        if LOAD:
            seq_index=8
            label_index=6
        else:
            seq_index=0
            label_index=1

        print("writing dev")
        write_file(test_lines, os.path.join(all_path,"dev.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        print("writing train")
        write_file(train_lines, os.path.join(all_path,"train.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        # print("writing tata dev")
        # write_file(tata_test_lines, os.path.join(tata_path,"dev.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        # print("writing tata train")
        # write_file(tata_train_lines, os.path.join(tata_path,"train.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        print("writing notata dev")
        write_file(notata_test_lines, os.path.join(notata_path,"dev.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        print("writing notata train")
        write_file(notata_train_lines, os.path.join(notata_path,"train.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
223
+
224
+
225
def Process_splice(args):
    """Convert splice-site test arrays into per-k-mer dev TSVs.

    Loads ``x_test.npy``/``y_test.npy`` from ``args.file_path`` and, for
    k in 3..6, writes ``<file_path>/<k>/dev.tsv``.  The train/dev
    conversion is kept below but commented out.
    """
    # X_train = np.load(os.path.join(args.file_path, "x_train.npy"))
    # X_dev = np.load(os.path.join(args.file_path, "x_dev.npy"))
    # Y_train = np.load(os.path.join(args.file_path, "y_train.npy"))
    # Y_dev = np.load(os.path.join(args.file_path, "y_dev.npy"))

    # assert len(X_train) == len(Y_train)
    # assert len(X_dev) == len(Y_dev)

    # for kmer in range(3,7):
    #     root_path = os.path.join(args.file_path, str(kmer))
    #     os.makedirs(root_path)
    #     f_train = open(os.path.join(root_path, "train.tsv"), "wt")
    #     f_dev = open(os.path.join(root_path, "dev.tsv"), "wt")
    #     tsv_train = csv.writer(f_train, delimiter='\t')
    #     tsv_dev = csv.writer(f_dev, delimiter='\t')
    #     tsv_train.writerow(["seq", "label"])
    #     tsv_dev.writerow(["seq", "label"])

    #     for i, seq in enumerate(X_train):
    #         sequence = get_kmer_sentence(str(seq), kmer)
    #         tsv_train.writerow([sequence, int(Y_train[i])])

    #     for j, seq in enumerate(X_dev):
    #         sequence = get_kmer_sentence(str(seq), kmer)
    #         tsv_dev.writerow([sequence, int(Y_dev[j])])

    X_test = np.load(os.path.join(args.file_path, "x_test.npy"))
    Y_test = np.load(os.path.join(args.file_path, "y_test.npy"))

    assert len(X_test) == len(Y_test)

    for kmer in range(3,7):
        root_path = os.path.join(args.file_path, str(kmer))
        # Fails if the k-mer folder already exists (no exist_ok).
        os.makedirs(root_path)
        f_test = open(os.path.join(root_path, "dev.tsv"), "wt")
        tsv_test = csv.writer(f_test, delimiter='\t')
        tsv_test.writerow(["seq", "label"])

        for i, seq in enumerate(X_test):
            sequence = get_kmer_sentence(str(seq), kmer)
            # assumes each Y_test row is one-hot — int() raises if there is
            # more than one set position; TODO confirm target encoding.
            label = int(np.where(Y_test[i]==1)[0])
            tsv_test.writerow([sequence, label])
268
+
269
+
270
def Process_prom_core(args):
    """Split the TATA/noTATA promoter-core CSVs into train/dev sets per k-mer.

    10% of each of the TATA and noTATA files is held out as the dev set; the
    remainder forms the training set.  With ``--dev`` a 1/9 slice of the
    training data becomes the dev split instead of the held-out rows.
    Output goes to ``<file_path>/<k>/`` for k in 3..6, with per-class dev
    files under ``tata/`` and ``notata/``.  Sequence is column 1, label
    column 2 of the source CSVs.
    """
    random.seed(args.seed)

    tata = args.file_path + "/TATA.csv"
    notata = args.file_path + "/noTATA.csv"
    tata_file = open(tata, "r", encoding="utf-8-sig")
    notata_file = open(notata, "r", encoding="utf-8-sig")
    tata_lines = list(csv.reader(tata_file, delimiter=",", quotechar=None))[1:]
    notata_lines = list(csv.reader(notata_file, delimiter=",", quotechar=None))[1:]

    random.shuffle(tata_lines)
    random.shuffle(notata_lines)

    num_tata_test = int(0.1*len(tata_lines))
    tata_test_lines = tata_lines[:num_tata_test]
    num_notata_test = int(0.1*len(notata_lines))
    notata_test_lines = notata_lines[:num_notata_test]

    train_lines = tata_lines[num_tata_test:] + notata_lines[num_notata_test:]
    if args.dev:
        # BUGFIX: was ``len(rest_lines)``, which is undefined here and raised
        # a NameError; the dev split is carved out of the training lines.
        num_dev = int(len(train_lines)/9.0)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]
    else:
        dev_lines = tata_test_lines + notata_test_lines

    print("Number train examples: %d" % (len(train_lines)))
    print("Number dev examples: %d" % (len(dev_lines)))

    for kmer in range(3,7):
        root_path = os.path.join(args.file_path,str(kmer))
        tata_path = os.path.join(root_path, "tata")
        notata_path = os.path.join(root_path, "notata")
        # makedirs creates root_path implicitly through the nested paths.
        os.makedirs(tata_path)
        os.makedirs(notata_path)

        write_file(tata_test_lines, os.path.join(tata_path,"dev.tsv"), kmer, head=False, seq_index=1, label_index=2)
        write_file(notata_test_lines, os.path.join(notata_path,"dev.tsv"), kmer, head=False, seq_index=1, label_index=2)
        write_file(train_lines, os.path.join(root_path,"train.tsv"), kmer, head=False, seq_index=1, label_index=2)
        write_file(dev_lines, os.path.join(root_path,"dev.tsv"), kmer, head=False, seq_index=1, label_index=2)
310
+
311
+
312
def Process_pair(args):
    """Build train/dev TSVs for the enhancer-promoter pair task.

    Reads paired FASTA files plus a label file per split from
    ``args.file_path`` (the dataset name is the last path component),
    k-merizes both sequences, and writes 3-column TSVs.  With ``--dev`` the
    training set is split 9:1 and the test set goes to a ``test/`` subfolder.
    """
    random.seed(args.seed)

    # Dataset name, e.g. <dir>/<name>/<name>_enhancer.fasta
    root_path = args.file_path.split('/')[-1]
    train_seq1_file = open(args.file_path+"/"+root_path+"_enhancer.fasta", "r")
    train_seq2_file = open(args.file_path+"/"+root_path+"_promoter.fasta", "r")
    train_label_file = open(args.file_path+"/"+root_path+"_label.txt", "r")
    test_seq1_file = open(args.file_path+"/"+root_path+"_enhancer_test.fasta", "r")
    test_seq2_file = open(args.file_path+"/"+root_path+"_promoter_test.fasta", "r")
    test_label_file = open(args.file_path+"/"+root_path+"_label_test.txt", "r")

    train_seq1 = train_seq1_file.readlines()
    train_seq2 = train_seq2_file.readlines()
    train_label = train_label_file.readlines()
    test_seq1 = test_seq1_file.readlines()
    test_seq2 = test_seq2_file.readlines()
    test_label = test_label_file.readlines()

    train_lines = []
    test_lines = []
    # FASTA layout: header on line 2*i, sequence on line 2*i+1.
    for i in range(len(train_label)):
        train_lines.append([train_seq1[2*i+1], train_seq2[2*i+1], train_label[i]])
    for i in range(len(test_label)):
        test_lines.append([test_seq1[2*i+1], test_seq2[2*i+1], test_label[i]])

    random.shuffle(train_lines)

    if args.dev:
        num_dev = int(len(train_lines)/10)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    output_path = args.output_path if args.output_path else os.path.join(args.file_path, str(args.kmer))
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    f_train = open(os.path.join(output_path, "train.tsv"), 'wt')
    train_w = csv.writer(f_train, delimiter='\t')
    train_w.writerow(["seq1", "seq2", "label"])
    if args.dev:
        # With a dev split, the held-out test file moves under test/.
        f_dev = open(os.path.join(output_path, "dev.tsv"), 'wt')
        dev_w = csv.writer(f_dev, delimiter='\t')
        dev_w.writerow(["seq1", "seq2", "label"])
        os.makedirs(os.path.join(output_path, "test"))
        f_test = open(os.path.join(output_path, "test", "dev.tsv"), 'wt')
        test_w = csv.writer(f_test, delimiter='\t')
        test_w.writerow(["seq1", "seq2", "label"])
    else:
        f_test = open(os.path.join(output_path, "dev.tsv"), 'wt')
        test_w = csv.writer(f_test, delimiter='\t')
        test_w.writerow(["seq1", "seq2", "label"])

    def write_file_pair(lines, writer, seq1_index=0, seq2_index=1, label_index=2):
        # k-merize both sequences; int() tolerates the trailing newline that
        # readlines() leaves on each label.
        for line in lines:
            seq1 = get_kmer_sentence(line[seq1_index],args.kmer)
            seq2 = get_kmer_sentence(line[seq2_index],args.kmer)
            writer.writerow([seq1, seq2, str(int(line[label_index]))])

    write_file_pair(train_lines, train_w)
    write_file_pair(test_lines, test_w)

    if args.dev:
        write_file_pair(dev_lines, dev_w)
375
+
376
+
377
def Process_p53_mut(args):
    """Tokenize the p53 mutation ``dev.csv`` into k-mer dev TSVs (k=3..6).

    The label column is forced to "0" (``label_index=None``) because the
    mutation scan carries no ground-truth labels; the sequence lives in
    column 2 of the source CSV.
    """
    random.seed(args.seed)

    source = os.path.join(args.file_path, "dev.csv")
    with open(source, "r", encoding="utf-8-sig") as handle:
        rows = list(csv.reader(handle, delimiter=",", quotechar=None))[1:]

    print(rows[0])

    for k in range(3, 7):
        # An explicit --output_path wins; otherwise one folder per k-mer.
        target_dir = args.output_path or os.path.join(args.file_path, str(k))
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        write_file(rows, os.path.join(target_dir, "dev.tsv"), k, head=True, seq_index=2, label_index=None)
393
+
394
+
395
def Process_p53(args):
    """Split the p53 train/test CSVs into per-k-mer train/dev TSVs.

    Tracks the longest sequence seen (printed at the end).  With ``--dev``
    a 1/9 slice of the training data becomes the dev split and the test
    set is written under ``test/``.  Sequence is column 2, label column 3.
    """
    random.seed(args.seed)

    train = os.path.join(args.file_path, "train.csv")
    test = os.path.join(args.file_path, "test.csv")
    train_file = open(train, "r", encoding="utf-8-sig")
    test_file = open(test, "r", encoding="utf-8-sig")

    train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
    test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]
    lines = train_lines + test_lines

    max_length = 0  # local; shadows the module-level ``max_length``
    for line in lines:
        if len(line[2]) > max_length:
            max_length = len(line[2])

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    if args.dev:
        num_dev = int(len(train_lines)/9)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    print(train_lines[0])

    for kmer in range(3, 7):
        # NOTE(review): when --output_path is set, all four k-mers write to
        # the same folder and overwrite each other — confirm if intended.
        output_path = args.output_path if args.output_path else os.path.join(args.file_path, str(kmer))
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        write_file(train_lines, os.path.join(output_path, "train.tsv"), kmer, head=True, seq_index=2, label_index=3)
        if args.dev:
            write_file(dev_lines, os.path.join(output_path, "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)
            os.makedirs(os.path.join(output_path, "test"))
            write_file(test_lines, os.path.join(output_path, "test", "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)
        else:
            write_file(test_lines, os.path.join(output_path, "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)

    print("max length: %d" % (max_length))
436
+
437
+
438
def Seperate_p53(args):
    """Split all p53 rows by label into POS/NEG per-k-mer dev files.

    Pools train.csv and test.csv rows; the last column is the label
    ('0' = negative) and the second-to-last the sequence.  Writes
    ``<file_path>/{POS,NEG}/<k>/dev.tsv`` for k in 3..6.
    """
    random.seed(args.seed)

    train = os.path.join(args.file_path, "train.csv")
    test = os.path.join(args.file_path, "test.csv")
    train_file = open(train, "r", encoding="utf-8-sig")
    test_file = open(test, "r", encoding="utf-8-sig")

    train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
    test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]
    lines = train_lines + test_lines

    POS = []
    NEG = []

    # Any non-'0' label counts as positive.
    for line in lines:
        if str(line[-1]) == '0':
            NEG.append([line[-2], line[-1]])
        else:
            POS.append([line[-2], line[-1]])



    for kmer in range(3,7):
        # Fails if the class folders already exist (no exist_ok).
        os.makedirs(os.path.join(args.file_path, "POS", str(kmer)))
        os.makedirs(os.path.join(args.file_path, "NEG", str(kmer)))

        write_file(POS, os.path.join(args.file_path, "POS", str(kmer), "dev.tsv"), kmer=kmer, head=True, seq_index=0, label_index=1)
        write_file(NEG, os.path.join(args.file_path, "NEG", str(kmer), "dev.tsv"), kmer=kmer, head=True, seq_index=0, label_index=1)
467
+
468
+
469
+
470
def Generate_prom_train_dev(args):
    """Split the TATA / noTATA promoter files into train and dev TSVs.

    Holds out 10% of each class as the dev set, writes the combined train
    and dev files plus per-class dev files.  NOTE(review): no random seed
    is set here, so the splits are not reproducible across runs.
    """
    # read TATA and noTATA files
    # BUGFIX: the two paths were swapped (``tata`` opened the noTATA file
    # and vice versa), mislabelling every per-class dev file downstream.
    tata = args.file_path + "/TATA_249to50.tsv"
    notata = args.file_path + "/noTATA_249to50.tsv"
    tata_file = open(tata, "r", encoding="utf-8-sig")
    notata_file = open(notata, "r", encoding="utf-8-sig")
    tata_lines = list(csv.reader(tata_file, delimiter="\t", quotechar=None))[1:]
    notata_lines = list(csv.reader(notata_file, delimiter="\t", quotechar=None))[1:]


    # shuffle all the data and split them
    random.shuffle(tata_lines)
    random.shuffle(notata_lines)
    num_tata_test = int(len(tata_lines)*0.1)
    tata_test_lines = tata_lines[:num_tata_test]
    num_notata_test = int(len(notata_lines)*0.1)
    notata_test_lines = notata_lines[:num_notata_test]
    train_lines = tata_lines[num_tata_test:] + notata_lines[num_notata_test:]
    test_lines = tata_test_lines + notata_test_lines


    write_file(train_lines, args.file_path+"/train.tsv", args.kmer)
    write_file(test_lines, args.file_path+"/dev.tsv", args.kmer)
    write_file(tata_test_lines, args.file_path+"/tata_dev.tsv", args.kmer)
    write_file(notata_test_lines, args.file_path+"/notata_dev.tsv", args.kmer)
495
+
496
def Process_690(args):
    """Convert every ENCODE-690-style dataset folder into k-mer train/dev TSVs.

    Each folder under ``args.file_path`` must contain train/test
    ``sequences_alph.npy`` (byte-encoded sequences) and ``targets.npy``.
    Sequences and labels are zipped, shuffled with ``args.seed`` (re-seeded
    per folder so each folder shuffles identically), and written to
    ``<output_path>/<kmer>/<folder>/{train,dev}.tsv``.
    """
    path = args.file_path
    all_folders = os.listdir(path)

    count = 0

    for folder in all_folders:
        # load data
        train_seq_path = os.path.join(args.file_path, folder, "train", "sequences_alph.npy")
        test_seq_path = os.path.join(args.file_path, folder, "test", "sequences_alph.npy")
        train_lab_path = os.path.join(args.file_path, folder, "train", "targets.npy")
        test_lab_path = os.path.join(args.file_path, folder, "test", "targets.npy")
        train_sequences = np.load(train_seq_path)
        test_sequences = np.load(test_seq_path)
        train_labels = np.load(train_lab_path)
        test_labels = np.load(test_lab_path)

        # Reshape 1-D arrays to columns so they can be concatenated side by side.
        train_sequences = train_sequences.reshape(train_sequences.shape[0],1)
        test_sequences = test_sequences.reshape(test_sequences.shape[0],1)
        train_labels = train_labels.reshape(train_labels.shape[0],1)
        test_labels = test_labels.reshape(test_labels.shape[0],1)

        # concat sequence and labels together
        trains = list(np.concatenate((train_sequences, train_labels), axis=1))
        tests = list(np.concatenate((test_sequences, test_labels), axis=1))

        random.seed(args.seed)
        random.shuffle(trains)
        random.shuffle(trains)
        random.shuffle(tests)
        random.shuffle(tests)


        # make output path
        output_path = os.path.join(args.output_path, str(args.kmer), folder)
        if not os.path.exists(output_path):
            os.makedirs(output_path)



        # write files
        f_train = open(os.path.join(output_path, "train.tsv"), 'wt')
        tsv_train = csv.writer(f_train, delimiter='\t')
        tsv_train.writerow(["sequence", "label"])
        for i in range(len(trains)):
            # sequences are stored as bytes in the .npy files
            sentence = get_kmer_sentence(trains[i][0].decode("utf-8"), args.kmer)
            tsv_train.writerow([sentence, int(trains[i][1])])

        f_dev = open(os.path.join(output_path, "dev.tsv"), 'wt')
        tsv_dev = csv.writer(f_dev, delimiter='\t')
        tsv_dev.writerow(["sequence", "label"])
        for i in range(len(tests)):
            sentence = get_kmer_sentence(tests[i][0].decode("utf-8"), args.kmer)
            tsv_dev.writerow([sentence, int(tests[i][1])])


        count += 1
        print("Finish %s folders" % (count))
554
+
555
+
556
def Process_mouse(args):
    """Sanity-check paired mouse train/test CSVs and report their sizes.

    Expects an even number of files that sort into (test, train) pairs per
    dataset; previously-generated k-mer folders "3".."6" are removed from
    the listing first.  The TSV emission is currently commented out, so
    this only prints per-dataset training line counts.
    """
    random.seed(args.seed)

    files = os.listdir(args.file_path)

    # Drop k-mer output folders from earlier runs, if present; a missing
    # folder raises ValueError and the listing is kept as-is.
    try:
        files.remove("3")
        files.remove("4")
        files.remove("5")
        files.remove("6")
    except ValueError:
        files = files

    files.sort()
    assert len(files) % 2 == 0

    num_task = int(len(files)/2)

    max_length = 0  # local; shadows the module-level ``max_length``; unused here

    for i in range(num_task):
        # Zero-pad single-digit dataset indices, e.g. 3 -> "03".
        index = str(i) if i > 9 else "0" + str(i)

        # After sorting, files[2*i] is the test CSV and files[2*i+1] the
        # train CSV of the same dataset; verify by name substitution.
        test_name = files[2*i].replace("test", "train")
        train_name = files[2*i+1]
        assert test_name == train_name

        test_file = os.path.join(args.file_path, files[2*i])
        train_file = os.path.join(args.file_path, files[2*i+1])
        train_file = open(train_file, "r", encoding="utf-8-sig")
        test_file = open(test_file, "r", encoding="utf-8-sig")
        train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
        test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]

        print("dataset %d : %d lines" % (i, len(train_lines)))

        # random.shuffle(train_lines)

        # for kmer in range(3, 7):
        #     os.makedirs(os.path.join(args.file_path, str(kmer), index))
        #     write_file(train_lines, os.path.join(args.file_path, str(kmer), index, "train.tsv"), kmer, head=True, seq_index=2, label_index=3)
        #     write_file(test_lines, os.path.join(args.file_path, str(kmer), index, "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)
598
+
599
+
600
+
601
def Process(args):
    """Convert a delimited file into a k-mer tokenized TSV.

    When ``--output_path`` is not given, output goes to
    ``<dir>/<kmer>/<filename>`` next to the input file (the k-mer folder is
    created on demand).  Column positions and header handling come from
    ``args.seq_index`` / ``args.label_index`` / ``args.head``.
    """
    if args.output_path is not None:  # PEP 8: identity comparison with None
        output_path = args.output_path
    else:
        root_path = "/".join(args.file_path.split("/")[:-1]) + "/" + str(args.kmer) + "/"
        output_path = root_path + args.file_path.split("/")[-1]
        if not os.path.exists(root_path):
            os.makedirs(root_path)

    # ``with`` guarantees the input handle is closed (was leaked before).
    with open(args.file_path, "r", encoding="utf-8-sig") as old_file:
        lines = list(csv.reader(old_file, delimiter=args.delimiter, quotechar=None))

    write_file(lines, output_path, args.kmer, head=args.head, seq_index=args.seq_index, label_index=args.label_index)
614
+
615
+
616
def main():
    """Parse CLI options and dispatch to the processor named by --task."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=1, type=int, help="K-mer")
    parser.add_argument("--seed", default=24, type=int,
                        help="Which random seed to use")
    parser.add_argument("--task", default="", type=str, help="which task to do")
    parser.add_argument("--file_path", default=None, type=str,
                        help="The path of the file to be processed")
    parser.add_argument("--output_path", default=None, type=str,
                        help="The path of the processed data")
    parser.add_argument("--delimiter", default=',', type=str,
                        help="The path of the processed data")
    parser.add_argument("--head", action="store_true",
                        help="The path of the processed data")
    parser.add_argument("--dev", action="store_true",
                        help="Use this flag to split data as (8:1:1), else (9:1)")
    parser.add_argument("--seq_index", default=2, type=int,
                        help="index of seq in the original csv file")
    parser.add_argument("--label_index", default=3, type=int,
                        help="index of label in the original csv file")
    args = parser.parse_args()

    # Table-driven dispatch; unknown task names fall back to the generic
    # Process() converter, matching the original if/elif chain.
    handlers = {
        "generate_prom": Generate_prom_train_dev,
        "shuffle": Shuffle,
        "find_train": Find_train,
        "prom_1000": Process_1000,
        "prom_1000_kmer": Process_1000_kmer,
        "splice": Process_splice,
        "pair": Process_pair,
        "p53": Process_p53,
        "p53_mut": Process_p53_mut,
        "sep_p53": Seperate_p53,
        "690": Process_690,
        "mouse": Process_mouse,
        "prom-core": Process_prom_core,
    }
    handlers.get(args.task, Process)(args)


if __name__ == "__main__":
    main()
examples/data_process_template/process_ner.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import csv
3
+ import os
4
+ import h5py
5
+ import numpy as np
6
+ import random
7
+ from process_pretrain_data import get_kmer_sequence
8
+ from multiprocessing import Pool
9
+
10
+
11
def generate_example(X, Y, kmer, index):
    """Convert one shard of sequences/targets into [tokens, labels] pairs.

    ``X`` holds byte-encoded DNA strings and ``Y`` per-sequence target
    matrices; the worker ``index`` is only used for progress logging.
    """
    # assert X.shape[0] == Y.shape[0]
    examples = []
    for pos, raw in enumerate(X):
        if pos % 1000 == 0:
            print("%s : %s" % (index, pos))

        # 200 leading pad labels, then the hit positions of Y[pos], then
        # trailing pads shrunk by the k-mer size.
        left_pad = list(np.zeros(200, dtype=int))
        hits = list(np.where(Y[pos] == 1)[1])
        right_pad = list(np.zeros(201 - kmer, dtype=int))
        labels = left_pad + hits + right_pad

        tokens = get_kmer_sequence(raw.decode("utf-8"), kmer)
        examples.append([tokens, labels])

    return examples
+ return lines
24
+
25
+
26
def Process(args):
    """Convert an HDF5 NER dataset into a "token label" training file.

    The HDF5 file stores paired X*/Y* chunks: the first half of the keys
    are treated as X chunks and each Y key is derived by replacing "X" with
    "Y".  Sequences are fanned out over ``args.n_process`` workers through
    ``generate_example`` and written one token/label pair per line, with a
    blank line between sequences, to ``<dir>/<kmer>/train.txt``.
    """
    filename = args.file_path
    h5 = h5py.File(filename, "r")
    # assumes keys sort so that the first half are the X chunks — TODO confirm
    num_chunks = len(h5.keys())//2
    keys = list(h5.keys())[:num_chunks]


    X = []

    for i, key in enumerate(keys):
        x_key = key
        y_key = x_key.replace("X","Y")

        X_l = h5[x_key]
        # first (and presumably only) element of each Y chunk — TODO confirm layout
        Y_l = h5[y_key][0]

        X.extend(X_l)

        if i == 0:
            Y = Y_l
        else:
            Y = np.concatenate([Y, Y_l], axis=0)

        print("%d : %d, %d, %s" % (i, len(X), Y.shape[0], str(key)))

    print(len(X))
    print(len(Y))

    n_proc = int(args.n_process)
    print("number of processes for converting feature: " + str(n_proc))
    p = Pool(n_proc)
    # Slice boundaries: equal slices, with the remainder folded into the last.
    indexes = [0]
    len_slice = int(len(X)/n_proc)
    for i in range(1, n_proc+1):
        if i != n_proc:
            indexes.append(len_slice*(i))
        else:
            indexes.append(len(X))

    results = []

    for i in range(n_proc):
        results.append(p.apply_async(generate_example, args=(X[indexes[i]:indexes[i+1]], Y[indexes[i]:indexes[i+1]], args.kmer, i)))
        print(str(i+1) + ' processor started !')

    p.close()
    p.join()

    # Collect worker outputs in submission order, preserving sequence order.
    lines = []
    for result in results:
        lines.extend(result.get())


    # Output goes next to the input file under a per-kmer folder, which
    # must already exist (open() does not create directories).
    path = "/".join(args.file_path.split('/')[:-1]) + "/" + str(args.kmer) + "/train.txt"
    print(path)
    file = open(path, "w")
    for line in lines:
        for k, word in enumerate(line[0]):
            file.write(str(word) + " " + str(line[1][k]) + "\n")
        file.write("\n")
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
def main():
    """CLI wrapper: parse the data-processing options and run Process()."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=1, type=int, help="K-mer")
    parser.add_argument("--n_process", default=24, type=int,
                        help="Number of processes for data processing")
    parser.add_argument("--file_path", default=None, type=str,
                        help="The path of the file to be processed")
    parser.add_argument("--output_path", default=None, type=str,
                        help="The path of the processed data")
    args = parser.parse_args()

    Process(args)


if __name__ == "__main__":
    main()
130
+
131
+
132
+
examples/data_process_template/process_pretrain_data.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ import numpy as np
4
+
5
+
6
+ def cut_no_overlap(length, kmer=1, max_prob=0.5):
7
+ cuts = []
8
+ while length:
9
+ if length <= 509+kmer:
10
+ cuts.append(length)
11
+ break
12
+ else:
13
+ if random.random() > max_prob:
14
+ cut = max(int(random.random()*(509+kmer)), 5)
15
+ else:
16
+ cut = 509+kmer
17
+ cuts.append(cut)
18
+ length -= cut
19
+
20
+ return cuts
21
+
22
+
23
+ def sampling(length, kmer=1, sampling_rate=1):
24
+ times = int(length*sampling_rate/256)
25
+ starts = []
26
+ ends = []
27
+ for i in range(times):
28
+ cut = max(int(random.random()*(509+kmer)), 5)
29
+ start = np.random.randint(length-kmer)
30
+ starts.append(start)
31
+ ends.append(start+cut)
32
+
33
+ return starts, ends
34
+
35
+
36
+ def sampling_fix(length, kmer=1, sampling_rate=1, fix_length=10245):
37
+ times = int(length*sampling_rate/fix_length)
38
+ starts = []
39
+ ends = []
40
+ for i in range(times):
41
+ cut = fix_length
42
+ start = np.random.randint(length-6-fix_length)
43
+ starts.append(start)
44
+ ends.append(start+cut)
45
+
46
+ return starts, ends
47
+
48
+
49
+ def get_kmer_sentence(original_string, kmer=1, stride=1):
50
+ if kmer == -1:
51
+ return original_string
52
+
53
+ sentence = ""
54
+ original_string = original_string.replace("\n", "")
55
+ i = 0
56
+ while i < len(original_string)-kmer:
57
+ sentence += original_string[i:i+kmer] + " "
58
+ i += stride
59
+
60
+ return sentence[:-1].strip("\"")
61
+
62
+
63
+
64
+ def get_kmer_sequence(original_string, kmer=1):
65
+ if kmer == -1:
66
+ return original_string
67
+
68
+ sequence = []
69
+ original_string = original_string.replace("\n", "")
70
+ for i in range(len(original_string)-kmer):
71
+ sequence.append(original_string[i:i+kmer])
72
+
73
+ sequence.append(original_string[-kmer:])
74
+ return sequence
75
+
76
+ def Process(args):
77
+ old_file = open(args.file_path, "r")
78
+ if args.output_path == None:
79
+ args.output_path = args.file_path
80
+
81
+ if args.sampling_rate!=1.0:
82
+ new_file_path = args.output_path + "_sam" + str(args.kmer)
83
+ else:
84
+ new_file_path = args.output_path + "_cut" + str(args.kmer)
85
+ new_file = open(new_file_path, "w")
86
+ line = old_file.readline()
87
+ while line:
88
+ line_length = len(line)
89
+ if args.sampling_rate != 1.0:
90
+ starts, ends = sampling_fix(length=line_length, kmer=args.kmer, sampling_rate=args.sampling_rate, fix_length=args.length)
91
+ for i in range(len(starts)):
92
+ new_line = line[starts[i]:ends[i]]
93
+ sentence = get_kmer_sentence(new_line, kmer=args.kmer)
94
+ new_file.write(sentence + "\n")
95
+
96
+ else:
97
+ cuts = cut_no_overlap(length=line_length, kmer=args.kmer)
98
+ start = 0
99
+ for cut in cuts:
100
+ new_line = line[start:start+cut]
101
+ sentence = get_kmer_sentence(new_line, kmer=args.kmer)
102
+ start += cut
103
+ new_file.write(sentence + "\n")
104
+
105
+ line = old_file.readline()
106
+
107
+
108
+ def main():
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument(
111
+ "--sampling_rate",
112
+ default=1.0,
113
+ type=float,
114
+ help="We will sample sampling_rate*total_length*2/512 times",
115
+ )
116
+ parser.add_argument(
117
+ "--kmer",
118
+ default=1,
119
+ type=int,
120
+ help="K-mer",
121
+ )
122
+ parser.add_argument(
123
+ "--length",
124
+ default=10000,
125
+ type=int,
126
+ help="Length of the sampled sequence",
127
+ )
128
+ parser.add_argument(
129
+ "--file_path",
130
+ default=None,
131
+ type=str,
132
+ help="The path of the file to be processed",
133
+ )
134
+ parser.add_argument(
135
+ "--output_path",
136
+ default=None,
137
+ type=str,
138
+ help="The path of the processed data",
139
+ )
140
+ args = parser.parse_args()
141
+
142
+ Process(args)
143
+
144
+
145
+
146
+
147
+ if __name__ == "__main__":
148
+ main()
examples/data_process_template/process_pretrain_data_multi.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing import Pool
2
+ import copy
3
+ import argparse
4
+
5
+ from process_pretrain_data import Process
6
+
7
+ # filenames = ['xaa', 'xab', 'xac', 'xad', 'xae', 'xaf', 'xag', 'xah', 'xai', 'xaj', 'xak', 'xal', 'xam', 'xan', 'xao', 'xap', 'xaq', 'xar', 'xas', 'xat', 'xau', 'xav', 'xaw']
8
+ # filenames = ['xaa', 'xab']
9
+
10
+ def main():
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument(
14
+ "--sampling_rate",
15
+ default=1.0,
16
+ type=float,
17
+ help="We will sample sampling_rate*total_length*2/512 times",
18
+ )
19
+ parser.add_argument(
20
+ "--kmer",
21
+ default=1,
22
+ type=int,
23
+ help="K-mer",
24
+ )
25
+ parser.add_argument(
26
+ "--length",
27
+ default=10000,
28
+ type=int,
29
+ help="Length of the sampled sequence",
30
+ )
31
+ parser.add_argument(
32
+ "--file_path",
33
+ default=None,
34
+ type=str,
35
+ help="The path of the file to be processed",
36
+ )
37
+ parser.add_argument(
38
+ "--output_path",
39
+ default="/home/zhihan/dna/data/split/",
40
+ type=str,
41
+ help="The path of the file to be processed",
42
+ )
43
+
44
+ args = parser.parse_args()
45
+
46
+ # multiprocess
47
+ p = Pool(22)
48
+
49
+ for i in range(1,23):
50
+ arg_new = copy.deepcopy(args)
51
+ arg_new.file_path = "/root/data/genome/" + "GRCh38.chr" + str(i) + ".fa"
52
+ arg_new.output_path = "/root/data/sub_001_6140/" + "GRCh38.chr" + str(i) + ".fa"
53
+ # arg_new.file_path = arg_new.output_path + filename
54
+ p.apply_async(Process, args=(arg_new,))
55
+
56
+ p.close()
57
+ p.join()
58
+
59
+
60
+
61
+
62
+ if __name__ == "__main__":
63
+ main()
examples/data_process_template/process_scan_prom_data.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import csv
4
+ import numpy as np
5
+ from process_pretrain_data import get_kmer_sentence
6
+
7
+
8
+
9
+
10
+ def Process(args):
11
+
12
+ SCAN_LIST = [int(500/(args.slide-1))*i for i in range(args.slide)]
13
+
14
+ old_file = open(args.file_path, "r", encoding="utf-8-sig")
15
+ old_lines = list(csv.reader(old_file, delimiter=",", quotechar=None))[1:]
16
+
17
+ if args.output_path:
18
+ root_path = args.output_path + "/"
19
+ else:
20
+ root_path = "/".join(args.file_path.split("/")[:-1]) + "/" + str(args.kmer) + "/"
21
+ if not os.path.exists(root_path):
22
+ os.makedirs(root_path)
23
+
24
+ labels = np.array([])
25
+ new_file = open(root_path+"dev.tsv", 'wt')
26
+ tsv_w = csv.writer(new_file, delimiter='\t')
27
+ tsv_w.writerow(["setence", "label"])
28
+
29
+ for line in old_lines:
30
+ label = line[6]
31
+ labels = np.append(labels, int(label))
32
+
33
+ for index in SCAN_LIST:
34
+ sub_sequence = line[8][index:index+500]
35
+ sub_sentence = get_kmer_sentence(sub_sequence, kmer=args.kmer)
36
+ tsv_w.writerow([sub_sentence, label])
37
+
38
+ np.save(root_path+"label.npy", labels)
39
+
40
+
41
+
42
+ def main():
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument(
45
+ "--kmer",
46
+ default=1,
47
+ type=int,
48
+ help="K-mer",
49
+ )
50
+ parser.add_argument(
51
+ "--file_path",
52
+ default=None,
53
+ type=str,
54
+ help="The path of the file to be processed",
55
+ )
56
+ parser.add_argument(
57
+ "--output_path",
58
+ default=None,
59
+ type=str,
60
+ help="The path of the processed data",
61
+ )
62
+ parser.add_argument(
63
+ "--slide",
64
+ default=11,
65
+ type=int,
66
+ help="How many 500s to use for the predictes result of 1000",
67
+ )
68
+ args = parser.parse_args()
69
+
70
+ Process(args)
71
+
72
+
73
+
74
+
75
+ if __name__ == "__main__":
76
+ main()
examples/gen_cCRE_emb_final.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
5
+ from Bio import SeqIO
6
+ from tqdm import tqdm
7
+
8
+ # ========== CONFIG ==========
9
+ MODEL_DIR = "/home/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000"
10
+ FASTA_DIR = "/home/n5huang/dna_token/cCRE_classes/chr1_files"
11
+ OUTPUT_DIR = "/home/n5huang/dna_token/outputs_cCREemb/"
12
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
13
+
14
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ MODEL_CLASSES = {"dna": (BertConfig, BertForMaskedLM, DNATokenizer)}
17
+
18
+ # ========== LOAD MODEL ==========
19
+ def load_model(model_dir):
20
+ config_class, model_class, tokenizer_class = MODEL_CLASSES['dna']
21
+ print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")
22
+
23
+ config = config_class.from_pretrained(model_dir)
24
+ model = BertModel.from_pretrained(model_dir, config=config)
25
+ tokenizer = tokenizer_class.from_pretrained(model_dir)
26
+
27
+ model.to(DEVICE)
28
+ model.eval()
29
+
30
+ print(f"✅ Model loaded on {DEVICE}, vocab size = {len(tokenizer)}")
31
+ return model, tokenizer
32
+
33
+ # ========== SEQUENCE HELPERS ==========
34
+ def seq_to_kmers(seq, k=6):
35
+ seq = seq.upper().replace("N", "")
36
+ if len(seq) < k:
37
+ return ""
38
+ return " ".join([seq[i:i+k] for i in range(len(seq)-k+1)])
39
+
40
+ def get_fasta_sequences(fasta_file):
41
+ sequences = []
42
+ for record in SeqIO.parse(fasta_file, "fasta"):
43
+ seq = str(record.seq).upper()
44
+ if len(seq) >= 50:
45
+ sequences.append(seq)
46
+ return sequences
47
+
48
+ # ========== EMBEDDING GENERATION ==========
49
+ def get_cls_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
50
+ inputs = tokenizer.batch_encode_plus(
51
+ batch_seqs,
52
+ padding="max_length",
53
+ truncation=True,
54
+ max_length=max_len,
55
+ return_tensors="pt"
56
+ )
57
+ # Move tensors to device
58
+ inputs = {k: v.to(device) for k, v in inputs.items()}
59
+
60
+ # Forward pass
61
+ with torch.no_grad():
62
+ outputs = model(**inputs)
63
+
64
+ # Extract CLS embedding
65
+ cls_embeddings = outputs[0][:, 0, :].cpu().numpy()
66
+ return cls_embeddings
67
+
68
+ # ========== MAIN EXECUTION ==========
69
+ def main():
70
+ model, tokenizer = load_model(MODEL_DIR)
71
+
72
+ fasta_files = [f for f in os.listdir(FASTA_DIR) if f.endswith(".fa")]
73
+ print(f"\nFound {len(fasta_files)} FASTA files in {FASTA_DIR}")
74
+
75
+ for fasta_file in fasta_files:
76
+ fasta_path = os.path.join(FASTA_DIR, fasta_file)
77
+ print(f"\n🚀 Processing: {fasta_file}")
78
+
79
+ sequences = get_fasta_sequences(fasta_path)
80
+ if len(sequences) == 0:
81
+ print(f"⚠️ No valid sequences found in {fasta_file}")
82
+ continue
83
+
84
+ # --- Remove duplicates ---
85
+ unique_sequences = list(set(sequences))
86
+ if len(unique_sequences) < len(sequences):
87
+ print(f"⚠️ Removed {len(sequences) - len(unique_sequences)} duplicate sequences")
88
+
89
+ # --- Convert to k-mers ---
90
+ kmers = [seq_to_kmers(s) for s in unique_sequences if len(s) >= 6]
91
+
92
+ # --- Sanity check on tokenization ---
93
+ example_tokens = tokenizer.tokenize(kmers[0])[:10]
94
+ print(f"🔹 Example tokens: {example_tokens}")
95
+
96
+ # --- Batch embedding extraction ---
97
+ all_embs = []
98
+ batch_size = 16
99
+ for i in tqdm(range(0, len(kmers), batch_size), desc=f"Embedding {fasta_file}"):
100
+ batch = kmers[i:i+batch_size]
101
+ batch_embs = get_cls_embeddings(batch, model, tokenizer, DEVICE)
102
+ all_embs.append(batch_embs)
103
+
104
+ all_embs = np.vstack(all_embs)
105
+ out_path = os.path.join(OUTPUT_DIR, fasta_file.replace(".fa", "_emb.npy"))
106
+ np.save(out_path, all_embs)
107
+
108
+ print(f"✅ Saved {all_embs.shape} embeddings to {out_path}")
109
+
110
+ print("\n🎉 All cell-type embeddings generated successfully!")
111
+
112
+ if __name__ == "__main__":
113
+ main()
examples/load_model_test.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
4
+ import argparse
5
+
6
+ # Define MODEL_CLASSES as it's required by your loadmodel function
7
+ MODEL_CLASSES = {
8
+ "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
9
+ # ... (other classes omitted for brevity)
10
+ }
11
+
12
+ def loadmodel(model_dir):
13
+ config_class, model_class, tokenizer_class = MODEL_CLASSES['dna'] # Changed 'DNA' to 'dna' for Python keys
14
+ print(f"Loading using: {config_class.__name__}, {model_class.__name__}, {tokenizer_class.__name__}")
15
+
16
+ # 1. Load Configuration
17
+ config = config_class.from_pretrained(
18
+ model_dir,
19
+ cache_dir = None,
20
+ )
21
+
22
+ # 2. Load Model Weights
23
+ # NOTE: Since you are extracting embeddings, we should use BertModel, not BertForMaskedLM
24
+ # BertModel is the base transformer without the MLM head.
25
+ base_model_class = BertModel if model_class == BertForMaskedLM else model_class
26
+
27
+ model = base_model_class.from_pretrained(
28
+ model_dir,
29
+ from_tf=bool(".ckpt" in model_dir),
30
+ config=config,
31
+ cache_dir= None,
32
+ )
33
+
34
+ # 3. Set Device
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ model.to(device)
37
+ model.eval() # Set model to evaluation mode
38
+ print(f"Model loaded onto device: {device}")
39
+
40
+ # 4. Load Tokenizer (using custom environment variables)
41
+ #tokenizer_class.vocab_files_names = {"vocab_file": os.getenv("VOCAB_NAME")}
42
+ #tokenizer_class.pretrained_vocab_files_map = {"vocab_file": {'dna': os.getenv("VOCAB_PATH")}} # Use 'dna' key
43
+ tokenizer = tokenizer_class.from_pretrained(model_dir)
44
+ print(f"Tokenizer vocabulary size: {len(tokenizer)}")
45
+
46
+ return config, model, tokenizer
47
+
48
+ # --- Main Call ---
49
+ # Use the environment variable set in the shell as the model directory
50
+ parser = argparse.ArgumentParser()
51
+ parser.add_argument("--MODEL_DIR", type=str, required=True)
52
+ args = parser.parse_args()
53
+
54
+ model_dir = args.MODEL_DIR
55
+
56
+ if model_dir != "/path/to/default":
57
+ config, model, tokenizer = loadmodel(model_dir)
58
+ print("Model and Tokenizer loaded successfully.")
59
+
60
+ embedding_layer = model.get_input_embeddings()
61
+ print(embedding_layer.weight.shape)
62
+
63
+
64
+ seq = "ACGTACGTACGT"
65
+ tokens = tokenizer.tokenize(" ".join([seq[i:i+6] for i in range(len(seq)-5)]))
66
+ print(tokens[:10])
67
+ else:
68
+ print("Error: MODEL_DIR environment variable was not set.")
69
+
examples/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorboardX
2
+ tensorboard
3
+ scikit-learn >= 0.22.2
4
+ seqeval
5
+ pyahocorasick
6
+ scipy
7
+ statsmodels
8
+ biopython
9
+ pandas
10
+ pybedtools
11
+ sentencepiece==0.1.91
examples/run_finetune.py ADDED
@@ -0,0 +1,1284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""
17
+
18
+
19
+ import argparse
20
+ import glob
21
+ import json
22
+ import logging
23
+ import os
24
+ import re
25
+ import shutil
26
+ import random
27
+ from multiprocessing import Pool
28
+ from typing import Dict, List, Tuple
29
+ from copy import deepcopy
30
+
31
+ import numpy as np
32
+ import torch
33
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
34
+ from torch.utils.data.distributed import DistributedSampler
35
+ from tqdm import tqdm, trange
36
+
37
+ from transformers import (
38
+ WEIGHTS_NAME,
39
+ AdamW,
40
+ AlbertConfig,
41
+ AlbertForSequenceClassification,
42
+ AlbertTokenizer,
43
+ BertConfig,
44
+ BertForSequenceClassification,
45
+ BertForLongSequenceClassification,
46
+ BertForLongSequenceClassificationCat,
47
+ BertTokenizer,
48
+ DNATokenizer,
49
+ DistilBertConfig,
50
+ DistilBertForSequenceClassification,
51
+ DistilBertTokenizer,
52
+ FlaubertConfig,
53
+ FlaubertForSequenceClassification,
54
+ FlaubertTokenizer,
55
+ RobertaConfig,
56
+ RobertaForSequenceClassification,
57
+ RobertaTokenizer,
58
+ XLMConfig,
59
+ XLMForSequenceClassification,
60
+ XLMRobertaConfig,
61
+ XLMRobertaForSequenceClassification,
62
+ XLMRobertaTokenizer,
63
+ XLMTokenizer,
64
+ XLNetConfig,
65
+ XLNetForSequenceClassification,
66
+ XLNetTokenizer,
67
+ get_linear_schedule_with_warmup,
68
+ )
69
+ from transformers import glue_compute_metrics as compute_metrics
70
+ from transformers import glue_convert_examples_to_features as convert_examples_to_features
71
+ from transformers import glue_output_modes as output_modes
72
+ from transformers import glue_processors as processors
73
+
74
+
75
+ try:
76
+ from torch.utils.tensorboard import SummaryWriter
77
+ except ImportError:
78
+ from tensorboardX import SummaryWriter
79
+
80
+
81
+ logger = logging.getLogger(__name__)
82
+
83
+ ALL_MODELS = sum(
84
+ (
85
+ tuple(conf.pretrained_config_archive_map.keys())
86
+ for conf in (
87
+ BertConfig,
88
+ XLNetConfig,
89
+ XLMConfig,
90
+ RobertaConfig,
91
+ DistilBertConfig,
92
+ AlbertConfig,
93
+ XLMRobertaConfig,
94
+ FlaubertConfig,
95
+ )
96
+ ),
97
+ (),
98
+ )
99
+
100
+ MODEL_CLASSES = {
101
+ "dna": (BertConfig, BertForSequenceClassification, DNATokenizer),
102
+ "dnalong": (BertConfig, BertForLongSequenceClassification, DNATokenizer),
103
+ "dnalongcat": (BertConfig, BertForLongSequenceClassificationCat, DNATokenizer),
104
+ "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
105
+ "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
106
+ "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
107
+ "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
108
+ "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
109
+ "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
110
+ "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
111
+ "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
112
+ }
113
+
114
+ TOKEN_ID_GROUP = ["bert", "dnalong", "dnalongcat", "xlnet", "albert"]
115
+
116
+ def set_seed(args):
117
+ random.seed(args.seed)
118
+ np.random.seed(args.seed)
119
+ torch.manual_seed(args.seed)
120
+ if args.n_gpu > 0:
121
+ torch.cuda.manual_seed_all(args.seed)
122
+
123
+
124
+ def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
125
+ ordering_and_checkpoint_path = []
126
+
127
+ glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
128
+
129
+ for path in glob_checkpoints:
130
+ if use_mtime:
131
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
132
+ else:
133
+ regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
134
+ if regex_match and regex_match.groups():
135
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
136
+
137
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
138
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
139
+ return checkpoints_sorted
140
+
141
+
142
+ def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
143
+ if not args.save_total_limit:
144
+ return
145
+ if args.save_total_limit <= 0:
146
+ return
147
+
148
+ # Check if we should delete older checkpoint(s)
149
+ checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
150
+ if len(checkpoints_sorted) <= args.save_total_limit:
151
+ return
152
+
153
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
154
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
155
+ for checkpoint in checkpoints_to_be_deleted:
156
+ logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
157
+ shutil.rmtree(checkpoint)
158
+
159
+ def train(args, train_dataset, model, tokenizer):
160
+ """ Train the model """
161
+ if args.local_rank in [-1, 0]:
162
+ tb_writer = SummaryWriter()
163
+
164
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
165
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
166
+ train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
167
+
168
+ if args.max_steps > 0:
169
+ t_total = args.max_steps
170
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
171
+ else:
172
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
173
+
174
+ # Prepare optimizer and schedule (linear warmup and decay)
175
+ no_decay = ["bias", "LayerNorm.weight"]
176
+ optimizer_grouped_parameters = [
177
+ {
178
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
179
+ "weight_decay": args.weight_decay,
180
+ },
181
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
182
+ ]
183
+
184
+ warmup_steps = args.warmup_steps if args.warmup_percent == 0 else int(args.warmup_percent*t_total)
185
+
186
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.beta1,args.beta2))
187
+ scheduler = get_linear_schedule_with_warmup(
188
+ optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
189
+ )
190
+
191
+ # Check if saved optimizer or scheduler states exist
192
+ if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
193
+ os.path.join(args.model_name_or_path, "scheduler.pt")
194
+ ):
195
+ # Load in optimizer and scheduler states
196
+ optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
197
+ scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
198
+
199
+ if args.fp16:
200
+ try:
201
+ from apex import amp
202
+ except ImportError:
203
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
204
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
205
+
206
+ # multi-gpu training (should be after apex fp16 initialization)
207
+ if args.n_gpu > 1:
208
+ model = torch.nn.DataParallel(model)
209
+
210
+ # Distributed training (should be after apex fp16 initialization)
211
+ if args.local_rank != -1:
212
+ model = torch.nn.parallel.DistributedDataParallel(
213
+ model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
214
+ )
215
+
216
+ # Train!
217
+ logger.info("***** Running training *****")
218
+ logger.info(" Num examples = %d", len(train_dataset))
219
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
220
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
221
+ logger.info(
222
+ " Total train batch size (w. parallel, distributed & accumulation) = %d",
223
+ args.train_batch_size
224
+ * args.gradient_accumulation_steps
225
+ * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
226
+ )
227
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
228
+ logger.info(" Total optimization steps = %d", t_total)
229
+
230
+ global_step = 0
231
+ epochs_trained = 0
232
+ steps_trained_in_current_epoch = 0
233
+ # Check if continuing training from a checkpoint
234
+ if os.path.exists(args.model_name_or_path):
235
+ # set global_step to gobal_step of last saved checkpoint from model path
236
+ try:
237
+ global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
238
+ except:
239
+ global_step = 0
240
+ epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
241
+ steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
242
+
243
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
244
+ logger.info(" Continuing training from epoch %d", epochs_trained)
245
+ logger.info(" Continuing training from global step %d", global_step)
246
+ logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
247
+
248
+ tr_loss, logging_loss = 0.0, 0.0
249
+ model.zero_grad()
250
+ train_iterator = trange(
251
+ epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
252
+ )
253
+ set_seed(args) # Added here for reproductibility
254
+
255
+ best_auc = 0
256
+ last_auc = 0
257
+ stop_count = 0
258
+
259
+ for _ in train_iterator:
260
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
261
+ for step, batch in enumerate(epoch_iterator):
262
+
263
+ # Skip past any already trained steps if resuming training
264
+ if steps_trained_in_current_epoch > 0:
265
+ steps_trained_in_current_epoch -= 1
266
+ continue
267
+
268
+ model.train()
269
+ batch = tuple(t.to(args.device) for t in batch)
270
+ inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
271
+ if args.model_type != "distilbert":
272
+ inputs["token_type_ids"] = (
273
+ batch[2] if args.model_type in TOKEN_ID_GROUP else None
274
+ ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
275
+ outputs = model(**inputs)
276
+ loss = outputs[0] # model outputs are always tuple in transformers (see doc)
277
+
278
+ if args.n_gpu > 1:
279
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
280
+ if args.gradient_accumulation_steps > 1:
281
+ loss = loss / args.gradient_accumulation_steps
282
+
283
+ if args.fp16:
284
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
285
+ scaled_loss.backward()
286
+ else:
287
+ loss.backward()
288
+
289
+ tr_loss += loss.item()
290
+ if (step + 1) % args.gradient_accumulation_steps == 0:
291
+ if args.fp16:
292
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
293
+ else:
294
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
295
+
296
+ optimizer.step()
297
+ scheduler.step() # Update learning rate schedule
298
+ model.zero_grad()
299
+ global_step += 1
300
+
301
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
302
+ logs = {}
303
+ if (
304
+ args.local_rank == -1 and args.evaluate_during_training
305
+ ): # Only evaluate when single GPU otherwise metrics may not average well
306
+ results = evaluate(args, model, tokenizer)
307
+
308
+
309
+ if args.task_name == "dna690":
310
+ # record the best auc
311
+ if results["auc"] > best_auc:
312
+ best_auc = results["auc"]
313
+
314
+ if args.early_stop != 0:
315
+ # record current auc to perform early stop
316
+ if results["auc"] < last_auc:
317
+ stop_count += 1
318
+ else:
319
+ stop_count = 0
320
+
321
+ last_auc = results["auc"]
322
+
323
+ if stop_count == args.early_stop:
324
+ logger.info("Early stop")
325
+ return global_step, tr_loss / global_step
326
+
327
+
328
+ for key, value in results.items():
329
+ eval_key = "eval_{}".format(key)
330
+ logs[eval_key] = value
331
+
332
+ loss_scalar = (tr_loss - logging_loss) / args.logging_steps
333
+ learning_rate_scalar = scheduler.get_lr()[0]
334
+ logs["learning_rate"] = learning_rate_scalar
335
+ logs["loss"] = loss_scalar
336
+ logging_loss = tr_loss
337
+
338
+ for key, value in logs.items():
339
+ tb_writer.add_scalar(key, value, global_step)
340
+ print(json.dumps({**logs, **{"step": global_step}}))
341
+
342
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
343
+ if args.task_name == "dna690" and results["auc"] < best_auc:
344
+ continue
345
+ checkpoint_prefix = "checkpoint"
346
+ # Save model checkpoint
347
+ output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
348
+ if not os.path.exists(output_dir):
349
+ os.makedirs(output_dir)
350
+ model_to_save = (
351
+ model.module if hasattr(model, "module") else model
352
+ ) # Take care of distributed/parallel training
353
+ model_to_save.save_pretrained(output_dir)
354
+ tokenizer.save_pretrained(output_dir)
355
+
356
+ logger.info("Saving model checkpoint to %s", output_dir)
357
+
358
+ _rotate_checkpoints(args, checkpoint_prefix)
359
+
360
+ if args.task_name != "dna690":
361
+ torch.save(args, os.path.join(output_dir, "training_args.bin"))
362
+ torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
363
+ torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
364
+ logger.info("Saving optimizer and scheduler states to %s", output_dir)
365
+
366
+ if args.max_steps > 0 and global_step > args.max_steps:
367
+ epoch_iterator.close()
368
+ break
369
+ if args.max_steps > 0 and global_step > args.max_steps:
370
+ train_iterator.close()
371
+ break
372
+
373
+ if args.local_rank in [-1, 0]:
374
+ tb_writer.close()
375
+
376
+ return global_step, tr_loss / global_step
377
+
378
+
379
def evaluate(args, model, tokenizer, prefix="", evaluate=True):
    """Run evaluation on a data split and compute task metrics.

    Args:
        args: parsed command-line namespace (device, batch sizes, task name, ...).
        model: the fine-tuned sequence-classification model.
        tokenizer: tokenizer matching the model checkpoint.
        prefix: sub-directory / log tag for this evaluation round.
        evaluate: forwarded to load_and_cache_examples; True selects the dev split.

    Returns:
        dict mapping metric name -> value, or the tuple
        (results, task, preds, labels, probs) when args.do_ensemble_pred is set.
    """
    # MNLI is evaluated twice (matched / mismatched); every other task once.
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
    if args.task_name[:3] == "dna":
        softmax = torch.nn.Softmax(dim=1)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=evaluate)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Sequential sampling keeps predictions aligned with the input order.
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Wrap once for multi-GPU evaluation.
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss, nb_eval_steps = 0.0, 0
        preds, probs, out_label_ids = None, None, None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT, RoBERTa and XLM-RoBERTa don't use segment ids.
                    inputs["token_type_ids"] = batch[2] if args.model_type in TOKEN_ID_GROUP else None
                tmp_eval_loss, logits = model(**inputs)[:2]

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            batch_logits = logits.detach().cpu().numpy()
            batch_labels = inputs["labels"].detach().cpu().numpy()
            if preds is None:
                preds, out_label_ids = batch_logits, batch_labels
            else:
                preds = np.append(preds, batch_logits, axis=0)
                out_label_ids = np.append(out_label_ids, batch_labels, axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            if args.task_name[:3] == "dna" and args.task_name != "dnasplice":
                # Binary DNA tasks: keep the full probability matrix for
                # ensembling, otherwise only the positive-class column.
                if args.do_ensemble_pred:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
                else:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32))[:, 1].numpy()
            elif args.task_name == "dnasplice":
                probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        if args.do_ensemble_pred:
            result = compute_metrics(eval_task, preds, out_label_ids, probs[:, 1])
        else:
            result = compute_metrics(eval_task, preds, out_label_ids, probs)
        results.update(result)

        if args.task_name == "dna690":
            # dna690 accumulates results for many datasets in a shared dir.
            eval_output_dir = args.result_dir
            if not os.path.exists(args.result_dir):
                os.makedirs(args.result_dir)
        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        with open(output_eval_file, "a") as writer:
            # DNA tasks prepend the dataset name to the result line.
            if args.task_name[:3] == "dna":
                eval_result = args.data_dir.split('/')[-1] + " "
            else:
                eval_result = ""

            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                eval_result = eval_result + str(result[key])[:5] + " "
            writer.write(eval_result + "\n")

    if args.do_ensemble_pred:
        return results, eval_task, preds, out_label_ids, probs
    else:
        return results
474
+
475
+
476
+
477
def predict(args, model, tokenizer, prefix=""):
    """Run inference on the dev split and save class probabilities to disk.

    Writes the probability array to ``<predict_dir>/pred_results.npy`` and
    logs the computed metrics.

    Args:
        args: parsed command-line namespace (device, batch sizes, task name, ...).
        model: fine-tuned sequence-classification model.
        tokenizer: tokenizer matching the model checkpoint.
        prefix: tag used in log messages.
    """
    pred_task_names = (args.task_name,)
    pred_outputs_dirs = (args.predict_dir,)
    if not os.path.exists(args.predict_dir):
        os.makedirs(args.predict_dir)
    softmax = torch.nn.Softmax(dim=1)

    for pred_task, pred_output_dir in zip(pred_task_names, pred_outputs_dirs):
        pred_dataset = load_and_cache_examples(args, pred_task, tokenizer, evaluate=True)

        if not os.path.exists(pred_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(pred_output_dir)

        args.pred_batch_size = args.per_gpu_pred_batch_size * max(1, args.n_gpu)
        # Sequential sampling keeps predictions aligned with the input rows.
        pred_sampler = SequentialSampler(pred_dataset)
        pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler, batch_size=args.pred_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running prediction {} *****".format(prefix))
        logger.info("  Num examples = %d", len(pred_dataset))
        logger.info("  Batch size = %d", args.pred_batch_size)
        preds = None
        out_label_ids = None
        for batch in tqdm(pred_dataloader, desc="Predicting"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT, RoBERTa and XLM-RoBERTa don't use segment ids.
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in TOKEN_ID_GROUP else None
                    )
                outputs = model(**inputs)
                _, logits = outputs[:2]

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        if args.output_mode == "classification":
            if args.task_name[:3] == "dna" and args.task_name != "dnasplice":
                # Binary DNA tasks: full matrix for ensembling, otherwise
                # only the positive-class probability.
                if args.do_ensemble_pred:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
                else:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32))[:, 1].numpy()
            elif args.task_name == "dnasplice":
                probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)

        if args.do_ensemble_pred:
            result = compute_metrics(pred_task, preds, out_label_ids, probs[:, 1])
        else:
            result = compute_metrics(pred_task, preds, out_label_ids, probs)

        pred_output_dir = args.predict_dir
        if not os.path.exists(pred_output_dir):
            # BUG FIX: was os.makedir(), which does not exist and would raise
            # AttributeError whenever the directory was missing.
            os.makedirs(pred_output_dir)
        output_pred_file = os.path.join(pred_output_dir, "pred_results.npy")
        logger.info("***** Pred results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
        np.save(output_pred_file, probs)
554
+
555
+
556
def format_attention(attention):
    """Collect per-layer attention maps into a single stacked tensor.

    Args:
        attention: iterable of per-layer tensors, each shaped
            1 x num_heads x seq_len x seq_len (batch dimension of 1).

    Returns:
        Tensor of shape 1 x num_layers x num_heads x seq_len x seq_len.

    Raises:
        ValueError: if any layer tensor is not 4-dimensional.
    """
    for layer_attention in attention:
        # Each layer must come in as 1 x num_heads x seq_len x seq_len.
        if len(layer_attention.shape) != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
    # Drop the batch dim per layer, stack layers, re-add a leading batch dim.
    return torch.stack([layer.squeeze(0) for layer in attention]).unsqueeze(0)
566
+
567
+
568
def visualize(args, model, tokenizer, kmer, prefix=""):
    """Extract last-layer attention and turn it into per-nucleotide scores.

    For every example, the attention the [CLS] query pays to each k-mer token
    (summed over heads) is spread across the k nucleotides that token covers,
    averaged per nucleotide, and L2-normalised.

    Args:
        args: parsed command-line namespace.
        model: fine-tuned model producing attentions in its outputs.
        tokenizer: tokenizer matching the model checkpoint.
        kmer: k-mer size used to tokenize the sequences.
        prefix: tag used in log messages.

    Returns:
        (scores, probs): per-nucleotide attention scores of shape
        (num_examples, max_seq_length) and predicted class probabilities.
    """
    pred_task_names = (args.task_name,)
    pred_outputs_dirs = (args.predict_dir,)
    if not os.path.exists(args.predict_dir):
        os.makedirs(args.predict_dir)
    softmax = torch.nn.Softmax(dim=1)

    for pred_task, pred_output_dir in zip(pred_task_names, pred_outputs_dirs):
        # Visualize either the train or the dev split.
        evaluate = not args.visualize_train
        pred_dataset = load_and_cache_examples(args, pred_task, tokenizer, evaluate=evaluate)

        if not os.path.exists(pred_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(pred_output_dir)

        args.pred_batch_size = args.per_gpu_pred_batch_size * max(1, args.n_gpu)
        # Sequential sampling keeps scores aligned with the input rows.
        pred_sampler = SequentialSampler(pred_dataset)
        pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler, batch_size=args.pred_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running prediction {} *****".format(prefix))
        logger.info("  Num examples = %d", len(pred_dataset))
        logger.info("  Batch size = %d", args.pred_batch_size)
        batch_size = args.pred_batch_size
        # Logit matrix: 3 classes for splice-site prediction, 2 otherwise.
        if args.task_name != "dnasplice":
            preds = np.zeros([len(pred_dataset), 2])
        else:
            preds = np.zeros([len(pred_dataset), 3])
        attention_scores = np.zeros([len(pred_dataset), 12, args.max_seq_length, args.max_seq_length])

        for index, batch in enumerate(tqdm(pred_dataloader, desc="Predicting")):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    # XLM, DistilBERT, RoBERTa and XLM-RoBERTa don't use segment ids.
                    inputs["token_type_ids"] = batch[2] if args.model_type in TOKEN_ID_GROUP else None
                outputs = model(**inputs)
                # Final element of the outputs holds per-layer attentions;
                # keep only the last layer.
                attention = outputs[-1][-1]
                _, logits = outputs[:2]

            start = index * batch_size
            stop = start + len(batch[0])
            preds[start:stop, :] = logits.detach().cpu().numpy()
            attention_scores[start:stop, :, :, :] = attention.cpu().numpy()

        if args.task_name != "dnasplice":
            probs = softmax(torch.tensor(preds, dtype=torch.float32))[:, 1].numpy()
        else:
            probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()

        scores = np.zeros([attention_scores.shape[0], attention_scores.shape[-1]])

        for index, attn_map in enumerate(attention_scores):
            # Attention weight from the [CLS] query (position 0) to each
            # k-mer token, summed over heads; skips position 0 itself.
            token_scores = [
                float(attn_map[:, 0, i].sum())
                for i in range(1, attn_map.shape[-1] - kmer + 2)
            ]

            # Zero the token just before the first zero-attention position
            # (presumably the [SEP]/padding boundary — TODO confirm).
            for i in range(len(token_scores) - 1):
                if token_scores[i + 1] == 0:
                    token_scores[i] = 0
                    break

            # Spread each token score over the kmer nucleotides it covers,
            # then average by coverage count.
            counts = np.zeros([len(token_scores) + kmer - 1])
            nt_scores = np.zeros([len(token_scores) + kmer - 1])
            for i, score in enumerate(token_scores):
                for j in range(kmer):
                    counts[i + j] += 1.0
                    nt_scores[i + j] += score
            nt_scores = nt_scores / counts
            nt_scores = nt_scores / np.linalg.norm(nt_scores)

            scores[index] = nt_scores

    return scores, probs
677
+
678
+
679
+
680
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Build (or load from cache) the TensorDataset for a task.

    Converted features are cached on disk inside args.data_dir so repeated
    runs skip the expensive example-to-feature conversion.

    Args:
        args: parsed command-line namespace.
        task: key into the global ``processors`` registry.
        tokenizer: tokenizer used to convert examples to features.
        evaluate: True loads the dev split, False loads the train split.

    Returns:
        TensorDataset of (input_ids, attention_mask, token_type_ids, labels).
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # In distributed training only rank 0 builds the cache; other ranks
        # wait here and read the cached file afterwards.
        torch.distributed.barrier()

    processor = processors[task]()
    output_mode = output_modes[task]
    # Cache file name encodes split, model, sequence length and task.
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    if args.do_predict:
        # Prediction mode uses a cache name independent of the model path.
        cached_features_file = os.path.join(
            args.data_dir,
            "cached_{}_{}_{}".format(
                "dev" if evaluate else "train",
                str(args.max_seq_length),
                str(task),
            ),
        )
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        if evaluate:
            examples = processor.get_dev_examples(args.data_dir)
        else:
            examples = processor.get_train_examples(args.data_dir)

        print("finish loading examples")

        # Parameters forwarded to convert_examples_to_features.
        max_length = args.max_seq_length
        pad_on_left = bool(args.model_type in ["xlnet"])  # xlnet pads on the left
        pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
        pad_token_segment_id = 4 if args.model_type in ["xlnet"] else 0

        if args.n_process == 1:
            features = convert_examples_to_features(
                examples,
                tokenizer,
                label_list=label_list,
                max_length=max_length,
                output_mode=output_mode,
                pad_on_left=pad_on_left,
                pad_token=pad_token,
                pad_token_segment_id=pad_token_segment_id,)
        else:
            n_proc = int(args.n_process)
            if evaluate:
                # Evaluation sets are smaller; use fewer workers.
                n_proc = max(int(n_proc / 4), 1)
            print("number of processes for converting feature: " + str(n_proc))
            pool = Pool(n_proc)
            # Slice boundaries splitting the examples across the workers;
            # the last slice absorbs the remainder.
            slice_len = int(len(examples) / n_proc)
            boundaries = [i * slice_len for i in range(n_proc)] + [len(examples)]

            async_results = []
            for i in range(n_proc):
                async_results.append(pool.apply_async(convert_examples_to_features, args=(examples[boundaries[i]:boundaries[i+1]], tokenizer, max_length, None, label_list, output_mode, pad_on_left, pad_token, pad_token_segment_id, True, )))
                print(str(i + 1) + ' processor started !')

            pool.close()
            pool.join()

            features = []
            for async_result in async_results:
                features.extend(async_result.get())

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Release the non-zero ranks waiting on the cache.
        torch.distributed.barrier()

    # Convert features to tensors and assemble the dataset.
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset
785
+
786
+
787
+ def main():
788
+ parser = argparse.ArgumentParser()
789
+
790
+ # Required parameters
791
+ parser.add_argument(
792
+ "--data_dir",
793
+ default=None,
794
+ type=str,
795
+ required=True,
796
+ help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
797
+ )
798
+ parser.add_argument(
799
+ "--model_type",
800
+ default=None,
801
+ type=str,
802
+ required=True,
803
+ help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
804
+ )
805
+ parser.add_argument(
806
+ "--n_process",
807
+ default=2,
808
+ type=int,
809
+ help="number of processes used for data process",
810
+ )
811
+ parser.add_argument(
812
+ "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
813
+ )
814
+ parser.add_argument(
815
+ "--model_name_or_path",
816
+ default=None,
817
+ type=str,
818
+ required=True,
819
+ help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
820
+ )
821
+ parser.add_argument(
822
+ "--task_name",
823
+ default=None,
824
+ type=str,
825
+ required=True,
826
+ help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
827
+ )
828
+ parser.add_argument(
829
+ "--output_dir",
830
+ default=None,
831
+ type=str,
832
+ required=True,
833
+ help="The output directory where the model predictions and checkpoints will be written.",
834
+ )
835
+
836
+
837
+ # Other parameters
838
+ parser.add_argument(
839
+ "--visualize_data_dir",
840
+ default=None,
841
+ type=str,
842
+ help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
843
+ )
844
+ parser.add_argument(
845
+ "--result_dir",
846
+ default=None,
847
+ type=str,
848
+ help="The directory where the dna690 and mouse will save results.",
849
+ )
850
+ parser.add_argument(
851
+ "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
852
+ )
853
+ parser.add_argument(
854
+ "--tokenizer_name",
855
+ default="",
856
+ type=str,
857
+ help="Pretrained tokenizer name or path if not the same as model_name",
858
+ )
859
+ parser.add_argument(
860
+ "--cache_dir",
861
+ default="",
862
+ type=str,
863
+ help="Where do you want to store the pre-trained models downloaded from s3",
864
+ )
865
+ parser.add_argument(
866
+ "--predict_dir",
867
+ default=None,
868
+ type=str,
869
+ help="The output directory of predicted result. (when do_predict)",
870
+ )
871
+ parser.add_argument(
872
+ "--max_seq_length",
873
+ default=128,
874
+ type=int,
875
+ help="The maximum total input sequence length after tokenization. Sequences longer "
876
+ "than this will be truncated, sequences shorter will be padded.",
877
+ )
878
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
879
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
880
+ parser.add_argument("--do_predict", action="store_true", help="Whether to do prediction on the given dataset.")
881
+ parser.add_argument("--do_visualize", action="store_true", help="Whether to calculate attention score.")
882
+ parser.add_argument("--visualize_train", action="store_true", help="Whether to visualize train.tsv or dev.tsv.")
883
+ parser.add_argument("--do_ensemble_pred", action="store_true", help="Whether to do ensemble prediction with kmer 3456.")
884
+ parser.add_argument(
885
+ "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
886
+ )
887
+ parser.add_argument(
888
+ "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
889
+ )
890
+
891
+ parser.add_argument(
892
+ "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
893
+ )
894
+ parser.add_argument(
895
+ "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
896
+ )
897
+ parser.add_argument(
898
+ "--per_gpu_pred_batch_size", default=8, type=int, help="Batch size per GPU/CPU for prediction.",
899
+ )
900
+ parser.add_argument(
901
+ "--early_stop", default=0, type=int, help="set this to a positive integet if you want to perfrom early stop. The model will stop \
902
+ if the auc keep decreasing early_stop times",
903
+ )
904
+ parser.add_argument(
905
+ "--predict_scan_size",
906
+ type=int,
907
+ default=1,
908
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
909
+ )
910
+ parser.add_argument(
911
+ "--gradient_accumulation_steps",
912
+ type=int,
913
+ default=1,
914
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
915
+ )
916
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
917
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
918
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
919
+ parser.add_argument("--beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
920
+ parser.add_argument("--beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
921
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
922
+ parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate of attention.")
923
+ parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate of intermidiete layer.")
924
+ parser.add_argument("--rnn_dropout", default=0.0, type=float, help="Dropout rate of intermidiete layer.")
925
+ parser.add_argument("--rnn", default="lstm", type=str, help="What kind of RNN to use")
926
+ parser.add_argument("--num_rnn_layer", default=2, type=int, help="Number of rnn layers in dnalong model.")
927
+ parser.add_argument("--rnn_hidden", default=768, type=int, help="Number of hidden unit in a rnn layer.")
928
+ parser.add_argument(
929
+ "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
930
+ )
931
+ parser.add_argument(
932
+ "--max_steps",
933
+ default=-1,
934
+ type=int,
935
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
936
+ )
937
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
938
+ parser.add_argument("--warmup_percent", default=0, type=float, help="Linear warmup over warmup_percent*total_steps.")
939
+
940
+ parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
941
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
942
+ parser.add_argument(
943
+ "--save_total_limit",
944
+ type=int,
945
+ default=None,
946
+ help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
947
+ )
948
+ parser.add_argument(
949
+ "--eval_all_checkpoints",
950
+ action="store_true",
951
+ help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
952
+ )
953
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
954
+ parser.add_argument(
955
+ "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
956
+ )
957
+ parser.add_argument(
958
+ "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
959
+ )
960
+ parser.add_argument(
961
+ "--visualize_models", type=int, default=None, help="The model used to do visualization. If None, use 3456.",
962
+ )
963
+ parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
964
+
965
+
966
+ parser.add_argument(
967
+ "--fp16",
968
+ action="store_true",
969
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
970
+ )
971
+ parser.add_argument(
972
+ "--fp16_opt_level",
973
+ type=str,
974
+ default="O1",
975
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
976
+ "See details at https://nvidia.github.io/apex/amp.html",
977
+ )
978
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
979
+ parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
980
+ parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
981
+
982
+
983
+ args = parser.parse_args()
984
+
985
+ if args.should_continue:
986
+ sorted_checkpoints = _sorted_checkpoints(args)
987
+ if len(sorted_checkpoints) == 0:
988
+ raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
989
+ else:
990
+ args.model_name_or_path = sorted_checkpoints[-1]
991
+
992
+ if (
993
+ os.path.exists(args.output_dir)
994
+ and os.listdir(args.output_dir)
995
+ and args.do_train
996
+ and not args.overwrite_output_dir
997
+ ):
998
+ raise ValueError(
999
+ "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
1000
+ args.output_dir
1001
+ )
1002
+ )
1003
+
1004
+ # Setup distant debugging if needed
1005
+ if args.server_ip and args.server_port:
1006
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
1007
+ import ptvsd
1008
+
1009
+ print("Waiting for debugger attach")
1010
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
1011
+ ptvsd.wait_for_attach()
1012
+
1013
+ # Setup CUDA, GPU & distributed training
1014
+ if args.local_rank == -1 or args.no_cuda:
1015
+ device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
1016
+ args.n_gpu = torch.cuda.device_count()
1017
+ else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
1018
+ torch.cuda.set_device(args.local_rank)
1019
+ device = torch.device("cuda", args.local_rank)
1020
+ torch.distributed.init_process_group(backend="nccl")
1021
+ args.n_gpu = 1
1022
+ args.device = device
1023
+
1024
+ # Setup logging
1025
+ logging.basicConfig(
1026
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
1027
+ datefmt="%m/%d/%Y %H:%M:%S",
1028
+ level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
1029
+ )
1030
+ logger.warning(
1031
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
1032
+ args.local_rank,
1033
+ device,
1034
+ args.n_gpu,
1035
+ bool(args.local_rank != -1),
1036
+ args.fp16,
1037
+ )
1038
+
1039
+ # Set seed
1040
+ set_seed(args)
1041
+
1042
+ # Prepare GLUE task
1043
+ args.task_name = args.task_name.lower()
1044
+ if args.task_name not in processors:
1045
+ raise ValueError("Task not found: %s" % (args.task_name))
1046
+ processor = processors[args.task_name]()
1047
+ args.output_mode = output_modes[args.task_name]
1048
+ label_list = processor.get_labels()
1049
+ num_labels = len(label_list)
1050
+
1051
+ # Load pretrained model and tokenizer
1052
+ if args.local_rank not in [-1, 0]:
1053
+ torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
1054
+
1055
+ args.model_type = args.model_type.lower()
1056
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
1057
+
1058
+ if not args.do_visualize and not args.do_ensemble_pred:
1059
+ config = config_class.from_pretrained(
1060
+ args.config_name if args.config_name else args.model_name_or_path,
1061
+ num_labels=num_labels,
1062
+ finetuning_task=args.task_name,
1063
+ cache_dir=args.cache_dir if args.cache_dir else None,
1064
+ )
1065
+
1066
+ config.hidden_dropout_prob = args.hidden_dropout_prob
1067
+ config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
1068
+ if args.model_type in ["dnalong", "dnalongcat"]:
1069
+ assert args.max_seq_length % 512 == 0
1070
+ config.split = int(args.max_seq_length/512)
1071
+ config.rnn = args.rnn
1072
+ config.num_rnn_layer = args.num_rnn_layer
1073
+ config.rnn_dropout = args.rnn_dropout
1074
+ config.rnn_hidden = args.rnn_hidden
1075
+
1076
+ tokenizer = tokenizer_class.from_pretrained(
1077
+ args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
1078
+ do_lower_case=args.do_lower_case,
1079
+ cache_dir=args.cache_dir if args.cache_dir else None,
1080
+ )
1081
+ model = model_class.from_pretrained(
1082
+ args.model_name_or_path,
1083
+ from_tf=bool(".ckpt" in args.model_name_or_path),
1084
+ config=config,
1085
+ cache_dir=args.cache_dir if args.cache_dir else None,
1086
+ )
1087
+ logger.info('finish loading model')
1088
+
1089
+ if args.local_rank == 0:
1090
+ torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
1091
+
1092
+ model.to(args.device)
1093
+
1094
+ logger.info("Training/evaluation parameters %s", args)
1095
+
1096
+ # Training
1097
+ if args.do_train:
1098
+ train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
1099
+ global_step, tr_loss = train(args, train_dataset, model, tokenizer)
1100
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
1101
+
1102
+ # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
1103
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and args.task_name != "dna690":
1104
+ # Create output directory if needed
1105
+ if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
1106
+ os.makedirs(args.output_dir)
1107
+
1108
+ logger.info("Saving model checkpoint to %s", args.output_dir)
1109
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
1110
+ # They can then be reloaded using `from_pretrained()`
1111
+ model_to_save = (
1112
+ model.module if hasattr(model, "module") else model
1113
+ ) # Take care of distributed/parallel training
1114
+ model_to_save.save_pretrained(args.output_dir)
1115
+ tokenizer.save_pretrained(args.output_dir)
1116
+
1117
+ # Good practice: save your training arguments together with the trained model
1118
+ torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
1119
+
1120
+ # Load a trained model and vocabulary that you have fine-tuned
1121
+ model = model_class.from_pretrained(args.output_dir)
1122
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir)
1123
+ model.to(args.device)
1124
+
1125
+ # Evaluation
1126
+ results = {}
1127
+ if args.do_eval and args.local_rank in [-1, 0]:
1128
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
1129
+ checkpoints = [args.output_dir]
1130
+ if args.eval_all_checkpoints:
1131
+ checkpoints = list(
1132
+ os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
1133
+ )
1134
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
1135
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
1136
+ for checkpoint in checkpoints:
1137
+ global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
1138
+ prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
1139
+
1140
+ model = model_class.from_pretrained(checkpoint)
1141
+ model.to(args.device)
1142
+ result = evaluate(args, model, tokenizer, prefix=prefix)
1143
+ result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
1144
+ results.update(result)
1145
+
1146
+ # Prediction
1147
+ predictions = {}
1148
+ if args.do_predict and args.local_rank in [-1, 0]:
1149
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
1150
+ checkpoint = args.output_dir
1151
+ logger.info("Predict using the following checkpoint: %s", checkpoint)
1152
+ prefix = ''
1153
+ model = model_class.from_pretrained(checkpoint)
1154
+ model.to(args.device)
1155
+ prediction = predict(args, model, tokenizer, prefix=prefix)
1156
+
1157
+ # Visualize
1158
+ if args.do_visualize and args.local_rank in [-1, 0]:
1159
+ visualization_models = [3,4,5,6] if not args.visualize_models else [args.visualize_models]
1160
+
1161
+ scores = None
1162
+ all_probs = None
1163
+
1164
+ for kmer in visualization_models:
1165
+ output_dir = args.output_dir.replace("/690", "/690/" + str(kmer))
1166
+ #checkpoint_name = os.listdir(output_dir)[0]
1167
+ #output_dir = os.path.join(output_dir, checkpoint_name)
1168
+
1169
+ tokenizer = tokenizer_class.from_pretrained(
1170
+ "dna"+str(kmer),
1171
+ do_lower_case=args.do_lower_case,
1172
+ cache_dir=args.cache_dir if args.cache_dir else None,
1173
+ )
1174
+ checkpoint = output_dir
1175
+ logger.info("Calculate attention score using the following checkpoint: %s", checkpoint)
1176
+ prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
1177
+ config = config_class.from_pretrained(
1178
+ output_dir,
1179
+ num_labels=num_labels,
1180
+ finetuning_task=args.task_name,
1181
+ cache_dir=args.cache_dir if args.cache_dir else None,
1182
+ )
1183
+ config.output_attentions = True
1184
+ model = model_class.from_pretrained(
1185
+ checkpoint,
1186
+ from_tf=bool(".ckpt" in args.model_name_or_path),
1187
+ config=config,
1188
+ cache_dir=args.cache_dir if args.cache_dir else None,
1189
+ )
1190
+ model.to(args.device)
1191
+ attention_scores, probs = visualize(args, model, tokenizer, prefix=prefix, kmer=kmer)
1192
+ if scores is not None:
1193
+ all_probs += probs
1194
+ scores += attention_scores
1195
+ else:
1196
+ all_probs = deepcopy(probs)
1197
+ scores = deepcopy(attention_scores)
1198
+
1199
+ all_probs = all_probs/float(len(visualization_models))
1200
+ np.save(os.path.join(args.predict_dir, "atten.npy"), scores)
1201
+ np.save(os.path.join(args.predict_dir, "pred_results.npy"), all_probs)
1202
+
1203
+ # ensemble prediction
1204
+ if args.do_ensemble_pred and args.local_rank in [-1, 0]:
1205
+
1206
+ for kmer in range(3,7):
1207
+ output_dir = os.path.join(args.output_dir, str(kmer))
1208
+ tokenizer = tokenizer_class.from_pretrained(
1209
+ "dna"+str(kmer),
1210
+ do_lower_case=args.do_lower_case,
1211
+ cache_dir=args.cache_dir if args.cache_dir else None,
1212
+ )
1213
+ checkpoint = output_dir
1214
+ logger.info("Calculate attention score using the following checkpoint: %s", checkpoint)
1215
+ prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
1216
+ config = config_class.from_pretrained(
1217
+ output_dir,
1218
+ num_labels=num_labels,
1219
+ finetuning_task=args.task_name,
1220
+ cache_dir=args.cache_dir if args.cache_dir else None,
1221
+ )
1222
+ config.output_attentions = True
1223
+ model = model_class.from_pretrained(
1224
+ args.model_name_or_path,
1225
+ from_tf=bool(".ckpt" in args.model_name_or_path),
1226
+ config=config,
1227
+ cache_dir=args.cache_dir if args.cache_dir else None,
1228
+ )
1229
+ model.to(args.device)
1230
+ if kmer == 3:
1231
+ args.data_dir = os.path.join(args.data_dir, str(kmer))
1232
+ else:
1233
+ args.data_dir = args.data_dir.replace("/"+str(kmer-1), "/"+str(kmer))
1234
+
1235
+ if args.result_dir.split('/')[-1] == "test.npy":
1236
+ results, eval_task, _, out_label_ids, probs = evaluate(args, model, tokenizer, prefix=prefix)
1237
+ elif args.result_dir.split('/')[-1] == "train.npy":
1238
+ results, eval_task, _, out_label_ids, probs = evaluate(args, model, tokenizer, prefix=prefix, evaluate=False)
1239
+ else:
1240
+ raise ValueError("file name in result_dir should be either test.npy or train.npy")
1241
+
1242
+ if kmer == 3:
1243
+ all_probs = deepcopy(probs)
1244
+ cat_probs = deepcopy(probs)
1245
+ else:
1246
+ all_probs += probs
1247
+ cat_probs = np.concatenate((cat_probs, probs), axis=1)
1248
+ print(cat_probs[0])
1249
+
1250
+
1251
+ all_probs = all_probs / 4.0
1252
+ all_preds = np.argmax(all_probs, axis=1)
1253
+
1254
+ # save label and data for stuck ensemble
1255
+ labels = np.array(out_label_ids)
1256
+ labels = labels.reshape(labels.shape[0],1)
1257
+ data = np.concatenate((cat_probs, labels), axis=1)
1258
+ random.shuffle(data)
1259
+ root_path = args.result_dir.replace(args.result_dir.split('/')[-1],'')
1260
+ if not os.path.exists(root_path):
1261
+ os.makedirs(root_path)
1262
+ # data_path = os.path.join(root_path, "data")
1263
+ # pred_path = os.path.join(root_path, "pred")
1264
+ # if not os.path.exists(data_path):
1265
+ # os.makedirs(data_path)
1266
+ # if not os.path.exists(pred_path):
1267
+ # os.makedirs(pred_path)
1268
+ # np.save(os.path.join(data_path, args.result_dir.split('/')[-1]), data)
1269
+ # np.save(os.path.join(pred_path, "pred_results.npy", all_probs[:,1]))
1270
+ np.save(args.result_dir, data)
1271
+ ensemble_results = compute_metrics(eval_task, all_preds, out_label_ids, all_probs[:,1])
1272
+ logger.info("***** Ensemble results {} *****".format(prefix))
1273
+ for key in sorted(ensemble_results.keys()):
1274
+ logger.info(" %s = %s", key, str(ensemble_results[key]))
1275
+
1276
+
1277
+
1278
+
1279
+
1280
+ return results
1281
+
1282
+
1283
+ if __name__ == "__main__":
1284
+ main()
examples/run_pretrain.py ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
18
+ GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
19
+ using a masked language modeling (MLM) loss.
20
+ """
21
+
22
+
23
+ import argparse
24
+ import glob
25
+ import logging
26
+ import os
27
+ import pickle
28
+ import random
29
+ import re
30
+ import shutil
31
+ from typing import Dict, List, Tuple
32
+ from copy import deepcopy
33
+ from multiprocessing import Pool
34
+
35
+ import numpy as np
36
+ import torch
37
+ from torch.nn.utils.rnn import pad_sequence
38
+ from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
39
+ from torch.utils.data.distributed import DistributedSampler
40
+ from tqdm import tqdm, trange
41
+
42
+ from transformers import (
43
+ WEIGHTS_NAME,
44
+ AdamW,
45
+ BertConfig,
46
+ BertForMaskedLM,
47
+ BertTokenizer,
48
+ DNATokenizer,
49
+ CamembertConfig,
50
+ CamembertForMaskedLM,
51
+ CamembertTokenizer,
52
+ DistilBertConfig,
53
+ DistilBertForMaskedLM,
54
+ DistilBertTokenizer,
55
+ GPT2Config,
56
+ GPT2LMHeadModel,
57
+ GPT2Tokenizer,
58
+ OpenAIGPTConfig,
59
+ OpenAIGPTLMHeadModel,
60
+ OpenAIGPTTokenizer,
61
+ PreTrainedModel,
62
+ PreTrainedTokenizer,
63
+ RobertaConfig,
64
+ RobertaForMaskedLM,
65
+ RobertaTokenizer,
66
+ get_linear_schedule_with_warmup,
67
+ )
68
+
69
+
70
+ try:
71
+ from torch.utils.tensorboard import SummaryWriter
72
+ except ImportError:
73
+ from tensorboardX import SummaryWriter
74
+
75
+
76
+ logger = logging.getLogger(__name__)
77
+
78
+
79
+ MODEL_CLASSES = {
80
+ "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
81
+ "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
82
+ "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
83
+ "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
84
+ "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
85
+ "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
86
+ "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
87
+ }
88
+
89
+ MASK_LIST = {
90
+ "3": [-1, 1],
91
+ "4": [-1, 1, 2],
92
+ "5": [-2, -1, 1, 2],
93
+ "6": [-2, -1, 1, 2, 3]
94
+ }
95
+
96
+
97
+ class TextDataset(Dataset):
98
+ def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
99
+ assert os.path.isfile(file_path)
100
+
101
+ block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)
102
+
103
+ directory, filename = os.path.split(file_path)
104
+ cached_features_file = os.path.join(
105
+ directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
106
+ )
107
+
108
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
109
+ logger.info("Loading features from cached file %s", cached_features_file)
110
+ with open(cached_features_file, "rb") as handle:
111
+ self.examples = pickle.load(handle)
112
+ else:
113
+ logger.info("Creating features from dataset file at %s", directory)
114
+
115
+ self.examples = []
116
+ with open(file_path, encoding="utf-8") as f:
117
+ text = f.read()
118
+
119
+ tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
120
+
121
+ for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
122
+ self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size]))
123
+ # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
124
+ # If your dataset is small, first you should loook for a bigger one :-) and second you
125
+ # can change this behavior by adding (model specific) padding.
126
+
127
+ logger.info("Saving features into cached file %s", cached_features_file)
128
+ with open(cached_features_file, "wb") as handle:
129
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
130
+
131
+ def __len__(self):
132
+ return len(self.examples)
133
+
134
+ def __getitem__(self, item):
135
+ return torch.tensor(self.examples[item], dtype=torch.long)
136
+
137
+ def convert_line_to_example(tokenizer, lines, max_length, add_special_tokens=True):
138
+ examples = tokenizer.batch_encode_plus(lines, add_special_tokens=add_special_tokens, max_length=max_length)["input_ids"]
139
+ return examples
140
+
141
+ class LineByLineTextDataset(Dataset):
142
+ def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
143
+ assert os.path.isfile(file_path)
144
+ # Here, we do not cache the features, operating under the assumption
145
+ # that we will soon use fast multithreaded tokenizers from the
146
+ # `tokenizers` repo everywhere =)
147
+ directory, filename = os.path.split(file_path)
148
+ cached_features_file = os.path.join(
149
+ directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
150
+ )
151
+
152
+ if os.path.exists(cached_features_file) and not args.overwrite_cache:
153
+ logger.info("Loading features from cached file %s", cached_features_file)
154
+ with open(cached_features_file, "rb") as handle:
155
+ self.examples = pickle.load(handle)
156
+ else:
157
+ logger.info("Creating features from dataset file at %s", file_path)
158
+
159
+ with open(file_path, encoding="utf-8") as f:
160
+ lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
161
+
162
+ if args.n_process == 1:
163
+ self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]
164
+ else:
165
+ n_proc = args.n_process
166
+ p = Pool(n_proc)
167
+ indexes = [0]
168
+ len_slice = int(len(lines)/n_proc)
169
+ for i in range(1, n_proc+1):
170
+ if i != n_proc:
171
+ indexes.append(len_slice*(i))
172
+ else:
173
+ indexes.append(len(lines))
174
+ results = []
175
+ for i in range(n_proc):
176
+ results.append(p.apply_async(convert_line_to_example,[tokenizer, lines[indexes[i]:indexes[i+1]], block_size,]))
177
+ print(str(i) + " start")
178
+ p.close()
179
+ p.join()
180
+
181
+ self.examples = []
182
+ for result in results:
183
+ ids = result.get()
184
+ self.examples.extend(ids)
185
+
186
+ logger.info("Saving features into cached file %s", cached_features_file)
187
+ with open(cached_features_file, "wb") as handle:
188
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
189
+
190
+ def __len__(self):
191
+ return len(self.examples)
192
+
193
+ def __getitem__(self, i):
194
+ return torch.tensor(self.examples[i], dtype=torch.long)
195
+
196
+
197
+ def load_and_cache_examples(args, tokenizer, evaluate=False):
198
+ file_path = args.eval_data_file if evaluate else args.train_data_file
199
+ if args.line_by_line:
200
+ return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
201
+ else:
202
+ return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
203
+
204
+
205
+ def set_seed(args):
206
+ random.seed(args.seed)
207
+ np.random.seed(args.seed)
208
+ torch.manual_seed(args.seed)
209
+ if args.n_gpu > 0:
210
+ torch.cuda.manual_seed_all(args.seed)
211
+
212
+
213
+ def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
214
+ ordering_and_checkpoint_path = []
215
+
216
+ glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
217
+
218
+ for path in glob_checkpoints:
219
+ if use_mtime:
220
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
221
+ else:
222
+ regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
223
+ if regex_match and regex_match.groups():
224
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
225
+
226
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
227
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
228
+ return checkpoints_sorted
229
+
230
+
231
+ def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
232
+ if not args.save_total_limit:
233
+ return
234
+ if args.save_total_limit <= 0:
235
+ return
236
+
237
+ # Check if we should delete older checkpoint(s)
238
+ checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
239
+ if len(checkpoints_sorted) <= args.save_total_limit:
240
+ return
241
+
242
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
243
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
244
+ for checkpoint in checkpoints_to_be_deleted:
245
+ logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
246
+ shutil.rmtree(checkpoint)
247
+
248
+
249
+
250
+
251
+ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
253
+
254
+ mask_list = MASK_LIST[tokenizer.kmer]
255
+
256
+ if tokenizer.mask_token is None:
257
+ raise ValueError(
258
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
259
+ )
260
+
261
+ labels = inputs.clone()
262
+ # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
263
+ probability_matrix = torch.full(labels.shape, args.mlm_probability)
264
+ special_tokens_mask = [
265
+ tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
266
+ ]
267
+ probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
268
+ if tokenizer._pad_token is not None:
269
+ padding_mask = labels.eq(tokenizer.pad_token_id)
270
+ probability_matrix.masked_fill_(padding_mask, value=0.0)
271
+
272
+ masked_indices = torch.bernoulli(probability_matrix).bool()
273
+
274
+ # change masked indices
275
+ masks = deepcopy(masked_indices)
276
+ for i, masked_index in enumerate(masks):
277
+ end = torch.where(probability_matrix[i]!=0)[0].tolist()[-1]
278
+ mask_centers = set(torch.where(masked_index==1)[0].tolist())
279
+ new_centers = deepcopy(mask_centers)
280
+ for center in mask_centers:
281
+ for mask_number in mask_list:
282
+ current_index = center + mask_number
283
+ if current_index <= end and current_index >= 1:
284
+ new_centers.add(current_index)
285
+ new_centers = list(new_centers)
286
+ masked_indices[i][new_centers] = True
287
+
288
+
289
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
290
+
291
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
292
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
293
+ inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
294
+
295
+ # 10% of the time, we replace masked input tokens with random word
296
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
297
+ random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
298
+ inputs[indices_random] = random_words[indices_random]
299
+
300
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
301
+ return inputs, labels
302
+
303
+
304
+ def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
305
+ """ Train the model """
306
+ if args.local_rank in [-1, 0]:
307
+ tb_writer = SummaryWriter()
308
+
309
+ args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
310
+
311
+ def collate(examples: List[torch.Tensor]):
312
+ if tokenizer._pad_token is None:
313
+ return pad_sequence(examples, batch_first=True)
314
+ return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
315
+
316
+ train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
317
+ train_dataloader = DataLoader(
318
+ train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
319
+ )
320
+
321
+ if args.max_steps > 0:
322
+ t_total = args.max_steps
323
+ args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
324
+ else:
325
+ t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
326
+
327
+ # Prepare optimizer and schedule (linear warmup and decay)
328
+ no_decay = ["bias", "LayerNorm.weight"]
329
+ optimizer_grouped_parameters = [
330
+ {
331
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
332
+ "weight_decay": args.weight_decay,
333
+ },
334
+ {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
335
+ ]
336
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.beta1,args.beta2))
337
+ scheduler = get_linear_schedule_with_warmup(
338
+ optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
339
+ )
340
+
341
+ # Check if saved optimizer or scheduler states exist
342
+ if (
343
+ args.model_name_or_path
344
+ and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
345
+ and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
346
+ ):
347
+ # Load in optimizer and scheduler states
348
+ optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
349
+ scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
350
+
351
+ if args.fp16:
352
+ try:
353
+ from apex import amp
354
+ except ImportError:
355
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
356
+ model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
357
+
358
+ # multi-gpu training (should be after apex fp16 initialization)
359
+ if args.n_gpu > 1:
360
+ model = torch.nn.DataParallel(model)
361
+
362
+ # Distributed training (should be after apex fp16 initialization)
363
+ if args.local_rank != -1:
364
+ model = torch.nn.parallel.DistributedDataParallel(
365
+ model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
366
+ )
367
+
368
+ # Train!
369
+ logger.info("***** Running training *****")
370
+ logger.info(" Num examples = %d", len(train_dataset))
371
+ logger.info(" Num Epochs = %d", args.num_train_epochs)
372
+ logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
373
+ logger.info(
374
+ " Total train batch size (w. parallel, distributed & accumulation) = %d",
375
+ args.train_batch_size
376
+ * args.gradient_accumulation_steps
377
+ * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
378
+ )
379
+ logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
380
+ logger.info(" Total optimization steps = %d", t_total)
381
+
382
+ global_step = 0
383
+ epochs_trained = 0
384
+ steps_trained_in_current_epoch = 0
385
+ # Check if continuing training from a checkpoint
386
+ if args.model_name_or_path and os.path.exists(args.model_name_or_path):
387
+ try:
388
+ # set global_step to gobal_step of last saved checkpoint from model path
389
+ checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
390
+ global_step = int(checkpoint_suffix)
391
+ epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
392
+ steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
393
+
394
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
395
+ logger.info(" Continuing training from epoch %d", epochs_trained)
396
+ logger.info(" Continuing training from global step %d", global_step)
397
+ logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
398
+ except ValueError:
399
+ logger.info(" Starting fine-tuning.")
400
+
401
+ tr_loss, logging_loss = 0.0, 0.0
402
+
403
+ model_to_resize = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
404
+ model_to_resize.resize_token_embeddings(len(tokenizer))
405
+
406
+ model.zero_grad()
407
+ train_iterator = trange(
408
+ epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
409
+ )
410
+ set_seed(args) # Added here for reproducibility
411
+ ids_set = {'0':0,'1':0,'2':0,'3':0,'4':0,'5':0,'6':0,'7':0,'8':0}
412
+ for _ in train_iterator:
413
+ epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
414
+ for step, batch in enumerate(epoch_iterator):
415
+
416
+ # Skip past any already trained steps if resuming training
417
+ if steps_trained_in_current_epoch > 0:
418
+ steps_trained_in_current_epoch -= 1
419
+ continue
420
+
421
+ inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
422
+ # print(inputs.shape)
423
+ # print(inputs)
424
+ # for i in range(len(inputs)):
425
+ # for j in range(len(inputs[i])):
426
+ # ids_set[str(int(inputs[i][j]))] += 1
427
+ # print(ids_set)
428
+ inputs = inputs.to(args.device)
429
+ labels = labels.to(args.device)
430
+ model.train()
431
+ outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
432
+ loss = outputs[0] # model outputs are always tuple in transformers (see doc)
433
+
434
+ if args.n_gpu > 1:
435
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
436
+ if args.gradient_accumulation_steps > 1:
437
+ loss = loss / args.gradient_accumulation_steps
438
+
439
+ if args.fp16:
440
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
441
+ scaled_loss.backward()
442
+ else:
443
+ loss.backward()
444
+
445
+ tr_loss += loss.item()
446
+ if (step + 1) % args.gradient_accumulation_steps == 0:
447
+ if args.fp16:
448
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
449
+ else:
450
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
451
+ optimizer.step()
452
+ scheduler.step() # Update learning rate schedule
453
+ model.zero_grad()
454
+ global_step += 1
455
+
456
+ if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
457
+ # Log metrics
458
+ if (
459
+ args.local_rank == -1 and args.evaluate_during_training
460
+ ): # Only evaluate when single GPU otherwise metrics may not average well
461
+ results = evaluate(args, model, tokenizer)
462
+ for key, value in results.items():
463
+ tb_writer.add_scalar("eval_{}".format(key), value, global_step)
464
+ tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
465
+ tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
466
+ logging_loss = tr_loss
467
+
468
+ if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
469
+ checkpoint_prefix = "checkpoint"
470
+ # Save model checkpoint
471
+ output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
472
+ os.makedirs(output_dir, exist_ok=True)
473
+ model_to_save = (
474
+ model.module if hasattr(model, "module") else model
475
+ ) # Take care of distributed/parallel training
476
+ model_to_save.save_pretrained(output_dir)
477
+ tokenizer.save_pretrained(output_dir)
478
+
479
+ torch.save(args, os.path.join(output_dir, "training_args.bin"))
480
+ logger.info("Saving model checkpoint to %s", output_dir)
481
+
482
+ _rotate_checkpoints(args, checkpoint_prefix)
483
+
484
+ torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
485
+ torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
486
+ logger.info("Saving optimizer and scheduler states to %s", output_dir)
487
+
488
+ if args.max_steps > 0 and global_step > args.max_steps:
489
+ epoch_iterator.close()
490
+ break
491
+ if args.max_steps > 0 and global_step > args.max_steps:
492
+ train_iterator.close()
493
+ break
494
+
495
+ if args.local_rank in [-1, 0]:
496
+ tb_writer.close()
497
+
498
+ return global_step, tr_loss / global_step
499
+
500
+
501
+ def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
502
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
503
+ eval_output_dir = args.output_dir
504
+
505
+ eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
506
+
507
+ if args.local_rank in [-1, 0]:
508
+ os.makedirs(eval_output_dir, exist_ok=True)
509
+
510
+ args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
511
+ # Note that DistributedSampler samples randomly
512
+
513
+ def collate(examples: List[torch.Tensor]):
514
+ if tokenizer._pad_token is None:
515
+ return pad_sequence(examples, batch_first=True)
516
+ return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)
517
+
518
+ eval_sampler = SequentialSampler(eval_dataset)
519
+ eval_dataloader = DataLoader(
520
+ eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
521
+ )
522
+
523
+ # multi-gpu evaluate
524
+ if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
525
+ model = torch.nn.DataParallel(model)
526
+
527
+ # Eval!
528
+ logger.info("***** Running evaluation {} *****".format(prefix))
529
+ logger.info(" Num examples = %d", len(eval_dataset))
530
+ logger.info(" Batch size = %d", args.eval_batch_size)
531
+ eval_loss = 0.0
532
+ nb_eval_steps = 0
533
+ model.eval()
534
+
535
+ for batch in tqdm(eval_dataloader, desc="Evaluating"):
536
+ inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
537
+ inputs = inputs.to(args.device)
538
+ labels = labels.to(args.device)
539
+
540
+ with torch.no_grad():
541
+ outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
542
+ lm_loss = outputs[0]
543
+ eval_loss += lm_loss.mean().item()
544
+ nb_eval_steps += 1
545
+
546
+ eval_loss = eval_loss / nb_eval_steps
547
+ perplexity = torch.exp(torch.tensor(eval_loss))
548
+
549
+ result = {"perplexity": perplexity}
550
+
551
+ output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
552
+ with open(output_eval_file, "a") as writer:
553
+ logger.info("***** Eval results {} *****".format(prefix))
554
+ for key in sorted(result.keys()):
555
+ logger.info(" %s = %s", key, str(result[key]))
556
+ writer.write(str(float(perplexity)) + "\n")
557
+ # writer.write("%s = %s\n" % (key, str(result[key])))
558
+
559
+ return result
560
+
561
+
562
+ def main():
563
+ parser = argparse.ArgumentParser()
564
+
565
+ # Required parameters
566
+ parser.add_argument(
567
+ "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
568
+ )
569
+ parser.add_argument(
570
+ "--output_dir",
571
+ type=str,
572
+ required=True,
573
+ help="The output directory where the model predictions and checkpoints will be written.",
574
+ )
575
+ parser.add_argument(
576
+ "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
577
+ )
578
+
579
+ # Other parameters
580
+ parser.add_argument(
581
+ "--eval_data_file",
582
+ default=None,
583
+ type=str,
584
+ help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
585
+ )
586
+ parser.add_argument(
587
+ "--line_by_line",
588
+ action="store_true",
589
+ help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
590
+ )
591
+ parser.add_argument(
592
+ "--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
593
+ )
594
+ parser.add_argument(
595
+ "--model_name_or_path",
596
+ default=None,
597
+ type=str,
598
+ help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
599
+ )
600
+
601
+ parser.add_argument(
602
+ "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
603
+ )
604
+ parser.add_argument(
605
+ "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
606
+ )
607
+
608
+ parser.add_argument(
609
+ "--config_name",
610
+ default=None,
611
+ type=str,
612
+ help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
613
+ )
614
+ parser.add_argument(
615
+ "--tokenizer_name",
616
+ default=None,
617
+ type=str,
618
+ help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
619
+ )
620
+ parser.add_argument(
621
+ "--cache_dir",
622
+ default=None,
623
+ type=str,
624
+ help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
625
+ )
626
+ parser.add_argument(
627
+ "--block_size",
628
+ default=-1,
629
+ type=int,
630
+ help="Optional input sequence length after tokenization."
631
+ "The training dataset will be truncated in block of this size for training."
632
+ "Default to the model max input length for single sentence inputs (take into account special tokens).",
633
+ )
634
+ parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
635
+ parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
636
+ parser.add_argument(
637
+ "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
638
+ )
639
+
640
+ parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
641
+ parser.add_argument(
642
+ "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
643
+ )
644
+ parser.add_argument(
645
+ "--gradient_accumulation_steps",
646
+ type=int,
647
+ default=1,
648
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
649
+ )
650
+ parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
651
+ parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
652
+ parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
653
+ parser.add_argument("--beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
654
+ parser.add_argument("--beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
655
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
656
+ parser.add_argument(
657
+ "--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
658
+ )
659
+ parser.add_argument(
660
+ "--max_steps",
661
+ default=-1,
662
+ type=int,
663
+ help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
664
+ )
665
+ parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
666
+
667
+ parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
668
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
669
+ parser.add_argument(
670
+ "--save_total_limit",
671
+ type=int,
672
+ default=None,
673
+ help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
674
+ )
675
+ parser.add_argument(
676
+ "--eval_all_checkpoints",
677
+ action="store_true",
678
+ help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
679
+ )
680
+ parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
681
+ parser.add_argument(
682
+ "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
683
+ )
684
+ parser.add_argument(
685
+ "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
686
+ )
687
+ parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
688
+ parser.add_argument("--n_process", type=int, default=1, help="")
689
+
690
+ parser.add_argument(
691
+ "--fp16",
692
+ action="store_true",
693
+ help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
694
+ )
695
+ parser.add_argument(
696
+ "--fp16_opt_level",
697
+ type=str,
698
+ default="O1",
699
+ help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
700
+ "See details at https://nvidia.github.io/apex/amp.html",
701
+ )
702
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
703
+ parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
704
+ parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
705
+ args = parser.parse_args()
706
+
707
+ if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
708
+ raise ValueError(
709
+ "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
710
+ "flag (masked language modeling)."
711
+ )
712
+ if args.eval_data_file is None and args.do_eval:
713
+ raise ValueError(
714
+ "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
715
+ "or remove the --do_eval argument."
716
+ )
717
+ if args.should_continue:
718
+ sorted_checkpoints = _sorted_checkpoints(args)
719
+ if len(sorted_checkpoints) == 0:
720
+ raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
721
+ else:
722
+ args.model_name_or_path = sorted_checkpoints[-1]
723
+
724
+ if (
725
+ os.path.exists(args.output_dir)
726
+ and os.listdir(args.output_dir)
727
+ and args.do_train
728
+ and not args.overwrite_output_dir
729
+ ):
730
+ raise ValueError(
731
+ "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
732
+ args.output_dir
733
+ )
734
+ )
735
+
736
+ # Setup distant debugging if needed
737
+ if args.server_ip and args.server_port:
738
+ # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
739
+ import ptvsd
740
+
741
+ print("Waiting for debugger attach")
742
+ ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
743
+ ptvsd.wait_for_attach()
744
+
745
+ # Setup CUDA, GPU & distributed training
746
+ if args.local_rank == -1 or args.no_cuda:
747
+ device = torch.device("cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu")
748
+ args.n_gpu = torch.cuda.device_count()
749
+ else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
750
+ torch.cuda.set_device(args.local_rank)
751
+ device = torch.device("cuda", args.local_rank)
752
+ torch.distributed.init_process_group(backend="nccl")
753
+ args.n_gpu = 1
754
+ args.device = device
755
+
756
+ # Setup logging
757
+ logging.basicConfig(
758
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
759
+ datefmt="%m/%d/%Y %H:%M:%S",
760
+ level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
761
+ )
762
+ logger.warning(
763
+ "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
764
+ args.local_rank,
765
+ device,
766
+ args.n_gpu,
767
+ bool(args.local_rank != -1),
768
+ args.fp16,
769
+ )
770
+
771
+ # Set seed
772
+ set_seed(args)
773
+
774
+ # Load pretrained model and tokenizer
775
+ if args.local_rank not in [-1, 0]:
776
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
777
+
778
+ config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
779
+
780
+ if args.config_name:
781
+ config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
782
+ elif args.model_name_or_path:
783
+ config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
784
+ else:
785
+ config = config_class()
786
+
787
+
788
+ if args.tokenizer_name:
789
+ tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
790
+ elif args.model_name_or_path:
791
+ tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
792
+ else:
793
+ raise ValueError(
794
+ "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
795
+ "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
796
+ )
797
+
798
+ # text = "C G A T A T A G"
799
+ # print(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
800
+
801
+ if args.block_size <= 0:
802
+ args.block_size = tokenizer.max_len
803
+ # Our input block size will be the max possible for the model
804
+ else:
805
+ args.block_size = min(args.block_size, tokenizer.max_len)
806
+
807
+ if args.model_name_or_path:
808
+ model = model_class.from_pretrained(
809
+ args.model_name_or_path,
810
+ from_tf=bool(".ckpt" in args.model_name_or_path),
811
+ config=config,
812
+ cache_dir=args.cache_dir,
813
+ )
814
+ else:
815
+ logger.info("Training new model from scratch")
816
+ model = model_class(config=config)
817
+
818
+ model.to(args.device)
819
+
820
+ if args.local_rank == 0:
821
+ torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
822
+
823
+ logger.info("Training/evaluation parameters %s", args)
824
+
825
+ # Training
826
+ if args.do_train:
827
+ if args.local_rank not in [-1, 0]:
828
+ torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
829
+
830
+ train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
831
+
832
+ if args.local_rank == 0:
833
+ torch.distributed.barrier()
834
+
835
+ global_step, tr_loss = train(args, train_dataset, model, tokenizer)
836
+ logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
837
+
838
+ # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
839
+ if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
840
+ # Create output directory if needed
841
+ if args.local_rank in [-1, 0]:
842
+ os.makedirs(args.output_dir, exist_ok=True)
843
+
844
+ logger.info("Saving model checkpoint to %s", args.output_dir)
845
+ # Save a trained model, configuration and tokenizer using `save_pretrained()`.
846
+ # They can then be reloaded using `from_pretrained()`
847
+ model_to_save = (
848
+ model.module if hasattr(model, "module") else model
849
+ ) # Take care of distributed/parallel training
850
+ model_to_save.save_pretrained(args.output_dir)
851
+ tokenizer.save_pretrained(args.output_dir)
852
+
853
+ # Good practice: save your training arguments together with the trained model
854
+ torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
855
+
856
+ # Load a trained model and vocabulary that you have fine-tuned
857
+ model = model_class.from_pretrained(args.output_dir)
858
+ tokenizer = tokenizer_class.from_pretrained(args.output_dir)
859
+ model.to(args.device)
860
+
861
+ # Evaluation
862
+ results = {}
863
+ if args.do_eval and args.local_rank in [-1, 0]:
864
+ checkpoints = [args.output_dir]
865
+ if args.eval_all_checkpoints:
866
+ checkpoints = list(
867
+ os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
868
+ )
869
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
870
+ logger.info("Evaluate the following checkpoints: %s", checkpoints)
871
+ for checkpoint in checkpoints:
872
+ global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
873
+ prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
874
+
875
+ model = model_class.from_pretrained(checkpoint)
876
+ model.to(args.device)
877
+ result = evaluate(args, model, tokenizer, prefix=prefix)
878
+ result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
879
+ results.update(result)
880
+
881
+ return results
882
+
883
+
884
+ if __name__ == "__main__":
885
+ main()
examples/run_pretrain.sh.save ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch with 4 processes (one for each GPU)
export KMER=6
export TRAIN_FILE=/home/n5huang/dna_token/output_tokens/all_tokenized_train.txt
export TEST_FILE=/home/n5huang/dna_token/output_tokens/all_tokenized_val.txt
export SOURCE=PATH_TO_DNABERT_REPO
export OUTPUT_PATH=output$KMER

# NOTE: comments were moved off the continuation lines below. In bash, an
# inline "#" after a trailing "\" terminates the command, silently splitting
# the invocation into broken pieces.
# --gradient_accumulation_steps 7: adjusted for 4 GPUs (effective batch 10 * 7 * 4 = 280)
# --max_steps 10000: recommended starting point for a custom dataset
python run_pretrain.py \
    --output_dir $OUTPUT_PATH \
    --model_type=dna \
    --tokenizer_name=dna$KMER \
    --config_name=$SOURCE/src/transformers/dnabert-config/bert-config-$KMER/config.json \
    --do_train \
    --train_data_file=$TRAIN_FILE \
    --do_eval \
    --eval_data_file=$TEST_FILE \
    --mlm \
    --gradient_accumulation_steps 7 \
    --per_gpu_train_batch_size 10 \
    --per_gpu_eval_batch_size 6 \
    --save_steps 500 \
    --save_total_limit 20 \
    --max_steps 10000 \
    --evaluate_during_training \
    --logging_steps 500 \
    --line_by_line \
    --learning_rate 4e-4 \
    --block_size 512 \
    --adam_epsilon 1e-6 \
    --weight_decay 0.01 \
    --beta1 0.9 \
    --beta2 0.98 \
    --mlm_probability 0.025 \
    --warmup_steps 10000 \
    --overwrite_output_dir \
    --n_process 24
examples/sample_data/ft/6/dev.tsv ADDED
The diff for this file is too large to render. See raw diff
 
examples/sample_data/ft/6/train.tsv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a824c48fe4b7cd1cde690882f9cd50dd628165e168453a714065d21a9c9bc7c
3
+ size 21847066
examples/sample_data/pre/6_3k.txt ADDED
The diff for this file is too large to render. See raw diff
 
examples/save_static_embeddings.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import os
4
+ from transformers import BertModel, BertConfig, DNATokenizer, BertForMaskedLM
5
+
6
# --- CONFIGURATION ---
OUTPUT_FOLDER = "6mer_pretrain_emb_adaptive"
OUTPUT_FILENAME = "static_adaptive_embed.npy"
CHECKPOINT_PATH = "/data/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000/"

if not CHECKPOINT_PATH:
    # Fixed: the old message blamed a "MODEL_DIR environment variable", but
    # the guarded value is the hard-coded CHECKPOINT_PATH constant above.
    raise EnvironmentError("CHECKPOINT_PATH is not set; point it at a pretrained checkpoint directory.")

# Maps a model-type key to its (config class, model class, tokenizer class).
MODEL_CLASSES = {
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
}
18
+
19
# --- CUSTOM LOADING FUNCTION (Modified to return BertModel for clean embeddings) ---
def loadmodel(model_dir):
    """Load the pretrained base BertModel and its tokenizer from model_dir.

    Uses the bare BertModel (rather than the MLM head) so the input-embedding
    layer can be accessed directly. Moves the model to GPU when available and
    puts it in eval mode.
    """
    config_class, _, tokenizer_class = MODEL_CLASSES['dna']

    config = config_class.from_pretrained(model_dir)
    # Explicitly load the BASE BERT MODEL (BertModel) to access the embedding layer.
    model = BertModel.from_pretrained(model_dir, config=config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Environment-variable vocab overrides were disabled here:
    #tokenizer_class.vocab_files_names = {"vocab_file": os.getenv("VOCAB_NAME")}
    #tokenizer_class.pretrained_vocab_files_map = {"vocab_file": {'dna': os.getenv("VOCAB_PATH")}}
    tokenizer = tokenizer_class.from_pretrained(model_dir)

    return model, tokenizer
39
+
40
# --- MAIN EXECUTION ---
if __name__ == "__main__":
    # Load the checkpoint and report the vocabulary size.
    print("Starting model and tokenizer load...")
    model, tokenizer = loadmodel(CHECKPOINT_PATH)
    print(f"Model and Tokenizer loaded successfully. Vocab size: {len(tokenizer)}")

    # The input-embedding matrix holds one vector per vocabulary token
    # (e.g. 4101 tokens x 768 dims).
    embedding_layer = model.get_input_embeddings()
    print(embedding_layer.weight.shape)

    # Detach the weights from the GPU and convert them to a NumPy array.
    static_embeddings_tensor = embedding_layer.weight.data.cpu()
    static_embeddings_array = static_embeddings_tensor.numpy()

    print(f"\nExtracted embedding tensor size: {static_embeddings_tensor.size()}")
    print(f"Extracted NumPy array shape: {static_embeddings_array.shape}")

    # Persist the matrix to <OUTPUT_FOLDER>/<OUTPUT_FILENAME>.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    output_path = os.path.join(OUTPUT_FOLDER, OUTPUT_FILENAME)
    np.save(output_path, static_embeddings_array)

    print(f"\n✅ Successfully saved static embeddings to: {output_path}")
examples/scripts/run_mut.sh ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Run each fine-tuned model in prediction mode over the original (ori) and
# mutated (mut) datasets, skipping models whose result directory already
# exists. Fixed: variable expansions in paths are now quoted (SC2086).
export MODEL_PATH=/gluster/zhihan/backup/dna/690/6

for model in $(ls $MODEL_PATH)
do
    export MODEL="$model"
    # Each model directory holds a single checkpoint; take the first entry.
    export CHECKPOINT=$(ls "$MODEL_PATH/$MODEL" | head -1)
    if [ ! -d "/gluster/zhihan/DNABERT/examples/data/ori_results/$MODEL" ]
    then
        python run_finetune.py \
            --model_type dna \
            --tokenizer_name=dna6 \
            --model_name_or_path "$MODEL_PATH/$MODEL/$CHECKPOINT" \
            --task_name dnaprom \
            --do_predict \
            --data_dir /gluster/zhihan/DNABERT/examples/data/ori \
            --max_seq_length 110 \
            --per_gpu_pred_batch_size=256 \
            --output_dir "$MODEL_PATH/$MODEL/$CHECKPOINT" \
            --predict_dir "/gluster/zhihan/DNABERT/examples/data/ori_results/$MODEL" \
            --fp16 \
            --n_process 96
    fi
done

for model in $(ls $MODEL_PATH)
do
    export MODEL="$model"
    export CHECKPOINT=$(ls "$MODEL_PATH/$MODEL" | head -1)
    if [ ! -d "/gluster/zhihan/DNABERT/examples/data/mut_results/$MODEL" ]
    then
        python run_finetune.py \
            --model_type dna \
            --tokenizer_name=dna6 \
            --model_name_or_path "$MODEL_PATH/$MODEL/$CHECKPOINT" \
            --task_name dnaprom \
            --do_predict \
            --data_dir /gluster/zhihan/DNABERT/examples/data/mut \
            --max_seq_length 110 \
            --per_gpu_pred_batch_size=256 \
            --output_dir "$MODEL_PATH/$MODEL/$CHECKPOINT" \
            --predict_dir "/gluster/zhihan/DNABERT/examples/data/mut_results/$MODEL" \
            --fp16 \
            --n_process 96
    fi
done
examples/scripts/uce.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Visualize attention for the first 345 fine-tuned models on the UCE dataset.
# Fixed: variable expansions in paths are now quoted (SC2086).
export MODEL_PATH=/home/zhihan/6
# for cp in $(ls $MODEL_PATH)
# do
#     cd $MODEL_PATH/$cp
#     mv checkpoin* checkpoint-0
# done

for model in $(ls $MODEL_PATH | head -345)
do
    export MODEL="$model"
    export CHECKPOINT=$(ls "$MODEL_PATH/$MODEL")
    CUDA_VISIBLE_DEVICES=0 python run_finetune.py \
        --model_type dna \
        --tokenizer_name=dna6 \
        --model_name_or_path "$MODEL_PATH/$MODEL/$CHECKPOINT" \
        --task_name dnaprom \
        --do_visualize \
        --visualize_data_dir /home/zhihan/data/uce/processed/ \
        --visualize_models 6 \
        --data_dir /home/zhihan/data/uce/processed/ \
        --max_seq_length 110 \
        --per_gpu_pred_batch_size=16 \
        --output_dir "$MODEL_PATH/$MODEL/$CHECKPOINT" \
        --predict_dir "/home/zhihan/data/uce/results/$MODEL" \
        --n_process 24
done
examples/visualize.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ import argparse
5
+ import os
6
+ import numpy as np
7
+
8
+ from transformers import BertTokenizer, BertModel, DNATokenizer
9
+ from process_pretrain_data import get_kmer_sentence
10
+
11
+
12
def format_attention(attention):
    """Stack per-layer attention tensors into one (layers, heads, seq, seq) tensor.

    Each element of `attention` is expected to be a 1 x num_heads x seq_len x
    seq_len tensor (batch dimension of 1); it is squeezed away before stacking.
    """
    layers = []
    for layer in attention:
        if layer.dim() != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
        layers.append(layer.squeeze(0))
    # num_layers x num_heads x seq_len x seq_len
    return torch.stack(layers)
22
+
23
def get_attention_dna(model, tokenizer, sentence_a, start, end):
    """Sum CLS-to-token attention over layers [start, end] for each inner token.

    Returns a list of floats, one per token excluding the first and last
    special tokens ([CLS]/[SEP]).
    """
    encoded = tokenizer.encode_plus(sentence_a, sentence_b=None, return_tensors='pt', add_special_tokens=True)
    input_ids = encoded['input_ids']
    attention = model(input_ids)[-1]
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())  # Batch index 0
    attn = format_attention(attention)
    # attn[layer, head, 0, i] = attention from the CLS position to token i.
    return [float(attn[start:end + 1, :, 0, i].sum()) for i in range(1, len(tokens) - 1)]
34
+
35
def get_real_score(attention_scores, kmer, metric):
    """Map k-mer-level attention scores back to per-nucleotide scores.

    Arguments:
        attention_scores -- sequence of per-k-mer scores (length n).
        kmer -- int, k-mer length; output has length n + kmer - 1.
        metric -- str, aggregation method; only "mean" is implemented.

    Returns:
        numpy array of per-nucleotide scores.

    Raises:
        ValueError -- if `metric` is not "mean" (previously this silently
        returned an all-zero array).
    """
    counts = np.zeros([len(attention_scores) + kmer - 1])
    real_scores = np.zeros([len(attention_scores) + kmer - 1])

    if metric == "mean":
        # Each k-mer covers `kmer` nucleotides; average overlapping contributions.
        for i, score in enumerate(attention_scores):
            for j in range(kmer):
                counts[i + j] += 1.0
                real_scores[i + j] += score

        real_scores = real_scores / counts
    else:
        raise ValueError("Unsupported metric: {}. Only 'mean' is implemented.".format(metric))

    return real_scores
50
+
51
SEQUENCE = "TGCCTGGCTTTTTGTAATTTTTGAAGAGACGGGGTTTTGCCATGATG"

def Visualize(args):
    """Compute nucleotide-level attention for a sequence and plot a heatmap.

    With --kmer 0 the L2-normalized scores of the 3/4/5/6-mer models are
    summed; otherwise a single model of the requested k is used.
    """
    raw_sentence = args.sequence if args.sequence else SEQUENCE

    if args.kmer == 0:
        scores = None
        for kmer in [3, 4, 5, 6]:
            # One sub-directory per k-mer model under args.model_path.
            model = BertModel.from_pretrained(os.path.join(args.model_path, str(kmer)), output_attentions=True)
            tokenizer = DNATokenizer.from_pretrained('dna' + str(kmer), do_lower_case=False)
            sentence_a = get_kmer_sentence(raw_sentence, kmer)
            tokens = sentence_a.split()

            attention = get_attention_dna(model, tokenizer, sentence_a, start=args.start_layer, end=args.end_layer)
            attention_scores = np.array(attention).reshape(-1, 1)

            real_scores = get_real_score(attention_scores, kmer, args.metric)
            # L2-normalize each model's scores before combining them.
            real_scores = real_scores / np.linalg.norm(real_scores)
            row = real_scores.reshape(1, -1)
            scores = row if scores is None else scores + row
    else:
        # load model and calculate attention for a single k-mer model
        model = BertModel.from_pretrained(args.model_path, output_attentions=True)
        tokenizer = DNATokenizer.from_pretrained('dna' + str(args.kmer), do_lower_case=False)
        sentence_a = get_kmer_sentence(raw_sentence, args.kmer)
        tokens = sentence_a.split()

        attention = get_attention_dna(model, tokenizer, sentence_a, start=args.start_layer, end=args.end_layer)
        attention_scores = np.array(attention).reshape(-1, 1)

        real_scores = get_real_score(attention_scores, args.kmer, args.metric)
        scores = real_scores.reshape(1, -1)

    ave = np.sum(scores) / scores.shape[1]
    print(ave)
    print(scores)

    # plot
    sns.set()
    ax = sns.heatmap(scores, cmap='YlGnBu', vmin=0)
    plt.show()
103
+
104
+
105
+
106
+
107
def main():
    """Parse CLI options for attention visualization and dispatch to Visualize."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=0, type=int, help="K-mer")
    parser.add_argument(
        "--model_path",
        default="/home/zhihan/dna/dna-transformers/examples/ft/690/p53-small/TAp73beta/3/",
        type=str,
        help="The path of the finetuned model",
    )
    parser.add_argument("--start_layer", default=11, type=int, help="Which layer to start")
    parser.add_argument("--end_layer", default=11, type=int, help="which layer to end")
    parser.add_argument(
        "--metric",
        default="mean",
        type=str,
        help="the metric used for integrate predicted kmer result to real result",
    )
    parser.add_argument("--sequence", default=None, type=str, help="the sequence for visualize")

    Visualize(parser.parse_args())


if __name__ == "__main__":
    main()
motif/find_motifs.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### ::: DNABERT-viz find motifs ::: ####
2
+
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+ import argparse
7
+ import motif_utils as utils
8
+
9
+
10
def main():
    """CLI entry point: load saved attention scores and run motif analysis."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", default=None, type=str, required=True,
        help="The input data dir. Should contain the sequence+label .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--predict_dir", default=None, type=str, required=True,
        help="Path where the attention scores were saved. Should contain both pred_results.npy and atten.npy",
    )
    parser.add_argument(
        "--window_size", default=24, type=int,
        help="Specified window size to be final motif length",
    )
    parser.add_argument(
        "--min_len", default=5, type=int,
        help="Specified minimum length threshold for contiguous region",
    )
    parser.add_argument(
        "--pval_cutoff", default=0.005, type=float,
        help="Cutoff FDR/p-value to declare statistical significance",
    )
    parser.add_argument(
        "--min_n_motif", default=3, type=int,
        help="Minimum instance inside motif to be filtered",
    )
    parser.add_argument(
        "--align_all_ties", action='store_true',
        help="Whether to keep all best alignments when ties encountered",
    )
    parser.add_argument(
        "--save_file_dir", default='.', type=str,
        help="Path to save outputs",
    )
    parser.add_argument("--verbose", action='store_true', help="Verbosity controller")
    parser.add_argument(
        "--return_idx", action='store_true',
        help="Whether the indices of the motifs are only returned",
    )

    # TODO: add the conditions
    args = parser.parse_args()

    # Attention scores and predictions saved by the prediction run.
    atten_scores = np.load(os.path.join(args.predict_dir, "atten.npy"))
    pred = np.load(os.path.join(args.predict_dir, "pred_results.npy"))

    dev = pd.read_csv(os.path.join(args.data_dir, "dev.tsv"), sep='\t', header=0)
    dev.columns = ['sequence', 'label']
    dev['seq'] = dev['sequence'].apply(utils.kmer2seq)

    # Split by label; attention rows are aligned with dev's row order.
    dev_pos = dev[dev['label'] == 1]
    dev_neg = dev[dev['label'] == 0]
    pos_atten_scores = atten_scores[dev_pos.index.values]
    neg_atten_scores = atten_scores[dev_neg.index.values]
    assert len(dev_pos) == len(pos_atten_scores)

    # run motif analysis
    merged_motif_seqs = utils.motif_analysis(
        dev_pos['seq'],
        dev_neg['seq'],
        pos_atten_scores,
        window_size=args.window_size,
        min_len=args.min_len,
        pval_cutoff=args.pval_cutoff,
        min_n_motif=args.min_n_motif,
        align_all_ties=args.align_all_ties,
        save_file_dir=args.save_file_dir,
        verbose=args.verbose,
        return_idx=args.return_idx,
    )


if __name__ == "__main__":
    main()
110
+ main()
111
+
112
+
motif/motif_utils.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #### ::: utils for DNABERT-viz motif search ::: ####
2
+
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+
7
def kmer2seq(kmers):
    """
    Convert kmers to original sequence

    Arguments:
    kmers -- str, kmers separated by space.

    Returns:
    seq -- str, original sequence.

    """
    pieces = kmers.split(" ")
    # Overlapping k-mers: take the first base of each k-mer, then the whole last one.
    seq = "".join(piece[0] for piece in pieces[:-1]) + pieces[-1]
    assert len(seq) == len(pieces) + len(pieces[0]) - 1
    return seq
24
+
25
def seq2kmer(seq, k):
    """
    Convert original sequence to kmers

    Arguments:
    seq -- str, original sequence.
    k -- int, kmer of length k specified.

    Returns:
    kmers -- str, kmers separated by space

    """
    # Sliding window of width k with stride 1.
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))
40
+
41
def contiguous_regions(condition, len_thres=5):
    """
    Find contiguous True runs of the boolean array "condition" that are at
    least len_thres long.
    Modified from and credit to: https://stackoverflow.com/a/4495197/3751373

    Arguments:
        condition -- boolean numpy array marking selected (high-attention) positions

    Keyword arguments:
        len_thres -- int, minimum run length to keep (default 5)

    Returns:
        idx -- (n, 2) array; column 0 is each run's start index, column 1 its
               end index (exclusive), one row per surviving run
    """
    # Indices where the boolean value flips; shift right by one so each index
    # points at the first element *after* the change.
    boundaries = np.diff(condition).nonzero()[0] + 1

    if condition[0]:
        # Array starts inside a True run: open it at position 0.
        boundaries = np.r_[0, boundaries]

    if condition[-1]:
        # Array ends inside a True run: close it at the array length.
        boundaries = np.r_[boundaries, condition.size]

    # Boundaries alternate open/close, so pairing them gives (start, end) rows.
    regions = boundaries.reshape(-1, 2)

    # Discard runs shorter than the threshold.
    keep = (regions[:, 1] - regions[:, 0]) >= len_thres
    return regions[keep]
83
+
84
def find_high_attention(score, min_len=5, **kwargs):
    """
    Locate contiguous high-attention sub-regions of length >= min_len in an
    array of attention scores.

    Arguments:
        score -- numpy array of attention scores for one sequence

    Keyword arguments:
        min_len -- int, minimum region length (default 5)
        **kwargs -- other input arguments:
            cond -- custom condition(s) selecting high attention; either a
                    single boolean array or a list of them (AND-ed position-wise)

    Returns:
        motif_regions -- (n, 2) array of [start, end) indices of high-attention regions
    """
    # Default heuristic: a position is "high attention" when its score is both
    # above the mean and above 10x the minimum score of the sequence.
    default_conds = [score > np.mean(score), score > 10 * np.min(score)]
    cond = list(map(all, zip(*default_conds)))

    if 'cond' in kwargs:  # caller-supplied condition(s) replace the default
        cond = kwargs['cond']
        if any(isinstance(c, list) for c in cond):
            # Multiple conditions given: combine them with a position-wise AND.
            cond = list(map(all, zip(*cond)))

    # Extract the contiguous regions that satisfy the condition and length bound.
    return contiguous_regions(np.asarray(cond), min_len)
122
+
123
def count_motif_instances(seqs, motifs, allow_multi_match=False):
    """
    Count motif occurrences across sequences, using the Aho-Corasick automaton
    for efficient multi-pattern matching.

    Arguments:
        seqs -- list, numpy array or pandas series of DNA sequences
        motifs -- list, numpy array or pandas series, a collection of motif
                  patterns to be matched against seqs

    Keyword arguments:
        allow_multi_match -- bool, count every occurrence within a sequence
                             instead of at most one per sequence (default False)

    Returns:
        motif_count -- dict mapping each motif to its instance count
                       (the original docstring claimed an int return; the
                       function has always returned a dict)
    """
    import ahocorasick  # third-party dependency, imported lazily as before
    from operator import itemgetter

    motif_count = {motif: 0 for motif in motifs}

    # Build the automaton once over all motif patterns.
    automaton = ahocorasick.Automaton()
    for idx, motif in enumerate(motifs):
        automaton.add_word(motif, (idx, motif))
    automaton.make_automaton()

    # PERF: the original checked membership against the motif *list* and a
    # growing matched *list* per sequence (both O(n) per match); use sets.
    motif_set = set(motifs)

    for seq in seqs:
        matches = sorted(map(itemgetter(1), automaton.iter(seq)))
        seen = set()  # motifs already counted for this sequence
        for _, match_seq in matches:
            assert match_seq in motif_set
            if allow_multi_match:
                motif_count[match_seq] += 1
            elif match_seq not in seen:
                # Count a motif at most once per sequence.
                motif_count[match_seq] += 1
                seen.add(match_seq)

    return motif_count
165
+
166
def motifs_hypergeom_test(pos_seqs, neg_seqs, motifs, p_adjust = 'fdr_bh', alpha = 0.05, verbose=False,
                          allow_multi_match=False, **kwargs):
    """
    Perform hypergeometric test to find significantly enriched motifs in positive sequences.
    Returns a list of adjusted p-values, ordered like `motifs`.

    Arguments:
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences
        neg_seqs -- list, numpy array or pandas series of negative DNA sequences
        motifs -- list, numpy array or pandas series, a collection of motif patterns
            to be matched to seqs

    Keyword arguments:
        p_adjust -- method used to correct for multiple testing problem. Options are same as
            statsmodels.stats.multitest (default 'fdr_bh'); pass None to skip correction
        alpha -- cutoff FDR/p-value to declare statistical significance (default 0.05)
        verbose -- verbosity argument (default False)
        allow_multi_match -- bool, whether to allow for counting multiple matchs (default False)

    Returns:
        pvals -- a list of p-values.

    """
    from scipy.stats import hypergeom
    import statsmodels.stats.multitest as multi


    pvals = []
    # Hypergeometric parameters: N = population size (all sequences),
    # K = number of positive sequences.
    N = len(pos_seqs) + len(neg_seqs)
    K = len(pos_seqs)
    # n = sequences containing the motif overall, x = among positives
    # (or instance counts when allow_multi_match=True).
    motif_count_all = count_motif_instances(pos_seqs+neg_seqs, motifs, allow_multi_match=allow_multi_match)
    motif_count_pos = count_motif_instances(pos_seqs, motifs, allow_multi_match=allow_multi_match)

    for motif in motifs:
        n = motif_count_all[motif]
        x = motif_count_pos[motif]
        # Survival function at x-1 gives P(X >= x): probability of observing
        # at least x positives among the motif-bearing sequences by chance.
        pval = hypergeom.sf(x-1, N, K, n)
        if verbose:
            if pval < 1e-5:
                print("motif {}: N={}; K={}; n={}; x={}; p={}".format(motif, N, K, n, x, pval))
        # pvals[motif] = pval
        pvals.append(pval)

    # adjust p-value
    if p_adjust is not None:
        # multipletests returns (reject, pvals_corrected, ...); keep index 1.
        pvals = list(multi.multipletests(pvals,alpha=alpha,method=p_adjust)[1])
    return pvals
213
+
214
def filter_motifs(pos_seqs, neg_seqs, motifs, cutoff=0.05, return_idx=False, **kwargs):
    """
    Wrapper returning the motifs (or their indices) that pass the
    hypergeometric enrichment test.

    Arguments:
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences
        neg_seqs -- list, numpy array or pandas series of negative DNA sequences
        motifs -- list, numpy array or pandas series of motif patterns to test

    Keyword arguments:
        cutoff -- cutoff FDR/p-value to declare statistical significance (default 0.05)
        return_idx -- bool, return indices into `motifs` instead of the motifs
                      themselves (default False)
        **kwargs -- forwarded to motifs_hypergeom_test

    Returns:
        list of significant motifs, or their indices when return_idx is True
    """
    pvals = motifs_hypergeom_test(pos_seqs, neg_seqs, motifs, **kwargs)
    significant = [i for i, pval in enumerate(pvals) if pval < cutoff]
    if return_idx:
        return significant
    return [motifs[i] for i in significant]
238
+
239
def merge_motifs(motif_seqs, min_len=5, align_all_ties=True, **kwargs):
    """
    Merge similar motifs in input motif_seqs.

    Keys are processed shortest-first. Each query motif is pairwise-aligned
    (internal gaps prohibited) against every key motif already accepted. If
    the best alignment clears the success threshold, the query's instances
    are folded into that key, shifting each instance's attention region by
    the left/right alignment offsets; otherwise the query becomes a new key.

    Arguments:
        motif_seqs -- nested dict, with the following structure:
            {motif: {'seq_idx': [...], 'atten_region_pos': [(start, end), ...]}}
            where seq_idx indicates indices of pos_seqs containing a motif, and
            atten_region_pos indicates where the high attention region is located.

    Keyword arguments:
        min_len -- int, specified minimum length threshold for contiguous region
            (default 5)
        align_all_ties -- bool, whether to keep all best alignments when ties
            are encountered (default True)
        **kwargs -- other input arguments, may include:
            - cond: custom numeric threshold declaring successful alignment.
              default is max(min_len - 1, 1/2 * min length of the two motifs)

    Returns:
        merged_motif_seqs -- nested dict with same structure as `motif_seqs`
    """
    from Bio import Align

    ### TODO: modify algorithm to improve efficiency later
    aligner = Align.PairwiseAligner()
    aligner.internal_gap_score = -10000.0  # prohibit internal gaps

    merged_motif_seqs = {}

    def _merge_into(key_motif, query_motif, alignment):
        """Fold one query motif's instances into an existing key motif,
        shifting attention regions by the alignment offsets (query - key)."""
        q_start, q_end = alignment.aligned[0][0]
        k_start, k_end = alignment.aligned[1][0]
        left_offset = q_start - k_start  # always query - key
        if q_end == len(query_motif) and k_end < len(key_motif):
            # left shift: query runs past the key on the left
            right_offset = k_end - len(key_motif)
        elif k_end == len(key_motif):
            # key fully covered on the right (inside, or right shift)
            right_offset = len(query_motif) - q_end
        else:
            # BUGFIX: the original if/elif chain left right_offset unbound
            # (NameError) when neither end is anchored. With internal gaps
            # prohibited this should not occur; keep the region end unchanged.
            right_offset = 0

        merged_motif_seqs[key_motif]['seq_idx'].extend(
            motif_seqs[query_motif]['seq_idx'])
        merged_motif_seqs[key_motif]['atten_region_pos'].extend(
            (pos[0] + left_offset, pos[1] - right_offset)
            for pos in motif_seqs[query_motif]['atten_region_pos'])

    for motif in sorted(motif_seqs, key=len):  # query motif, shortest first
        if not merged_motif_seqs:  # if empty
            merged_motif_seqs[motif] = motif_seqs[motif]  # add first one
            continue

        # Collect every key alignment that clears the success threshold.
        alignments = []
        key_motifs = []
        for key_motif in merged_motif_seqs.keys():
            if motif == key_motif:
                continue  # do not attempt to align to self
            # first is query, second is key within new dict;
            # query length >= key length is guaranteed by the sorted order
            alignment = aligner.align(motif, key_motif)[0]

            # condition to declare successful alignment
            cond = max((min_len - 1), 0.5 * min(len(motif), len(key_motif)))
            if 'cond' in kwargs:
                cond = kwargs['cond']  # override

            if alignment.score >= cond:
                alignments.append(alignment)
                key_motifs.append(key_motif)

        if not alignments:  # cannot align to anything: add as independent key
            merged_motif_seqs[motif] = motif_seqs[motif]
            continue

        # BUGFIX: the original took max(alignments, key=...) -- an Alignment
        # object -- and then compared each alignment to it with ==, which
        # never matches the *other* tied alignments, so score ties were
        # silently dropped. Compare numeric scores instead.
        best_score = max(a.score for a in alignments)
        best_idx = [i for i, a in enumerate(alignments) if a.score == best_score]

        if align_all_ties:
            for i in best_idx:
                _merge_into(key_motifs[i], motif, alignments[i])
        else:
            _merge_into(key_motifs[best_idx[0]], motif, alignments[best_idx[0]])

    return merged_motif_seqs
362
+
363
+
364
def make_window(motif_seqs, pos_seqs, window_size=24):
    """
    Extract fixed, equal-length sequences centered at each high-attention
    motif instance.

    Arguments:
        motif_seqs -- nested dict, with the following structure:
            {motif: {'seq_idx': [...], 'atten_region_pos': [(start, end), ...]}}
            where seq_idx indicates indices of pos_seqs containing a motif, and
            atten_region_pos indicates where the high attention region is located.
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences

    Keyword arguments:
        window_size -- int, specified window size to be final motif length
            (default 24)

    Returns:
        new_motif_seqs -- nested dict like motif_seqs, with an extra 'seqs'
            list holding the extracted windows; instances whose window would
            run off either end of the source sequence are dropped.
    """
    new_motif_seqs = {}

    for motif, instances in motif_seqs.items():
        new_motif_seqs[motif] = {'seq_idx': [], 'atten_region_pos': [], 'seqs': []}
        for seq_idx, (start, end) in zip(instances['seq_idx'],
                                         instances['atten_region_pos']):
            # REFACTOR: the original duplicated this logic in verbatim
            # even/odd branches; floor division unifies both cases
            # (odd padding puts the extra base on the right, as before).
            pad = window_size - (end - start)
            left_pad = pad // 2
            right_pad = pad - left_pad
            new_start, new_end = int(start - left_pad), int(end + right_pad)

            seq = pos_seqs[seq_idx]
            # NOTE(review): the right bound uses a strict '<', so a window
            # ending exactly at the sequence end is discarded -- behavior
            # preserved from the original implementation.
            if new_start >= 0 and new_end < len(seq):
                new_motif_seqs[motif]['seq_idx'].append(seq_idx)
                new_motif_seqs[motif]['atten_region_pos'].append((new_start, new_end))
                new_motif_seqs[motif]['seqs'].append(seq[new_start:new_end])

    return new_motif_seqs
411
+
412
+
413
+ ### make full pipeline
414
def motif_analysis(pos_seqs,
                   neg_seqs,
                   pos_atten_scores,
                   window_size = 24,
                   min_len = 4,
                   pval_cutoff = 0.005,
                   min_n_motif = 3,
                   align_all_ties = True,
                   save_file_dir = None,
                   **kwargs
                   ):

    """
    Wrapper function of full motif analysis tool based on DNABERT-viz.

    Pipeline: find high-attention regions -> hypergeometric filtering ->
    merge similar motifs -> extract fixed-length windows -> drop rare motifs
    -> optionally save sequences and weblogos to disk.

    Arguments:
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences
        neg_seqs -- list, numpy array or pandas series of negative DNA sequences
        pos_atten_scores -- numpy array of attention scores for postive DNA sequence

    Keyword arguments:
        window_size -- int, specified window size to be final motif length
            (default 24)
        min_len -- int, specified minimum length threshold for contiguous region
            (default 4)
        pval_cutoff -- float, cutoff FDR/p-value to declare statistical significance. (default 0.005)
        min_n_motif -- int, minimum instance inside motif to be filtered (default 3)
        align_all_ties -- bool, whether to keep all best alignments when ties encountered (default True)
        save_file_dir -- str, path to save outputs (default None)
        **kwargs -- other input arguments, may include:
            - verbose: bool, verbosity controller
            - atten_cond: custom conditions to filter/select high attention
                (list of boolean arrays)
            - return_idx: whether the indices of the motifs are only returned.
            - align_cond: custom condition used to declare successful alignment.
                default is score > max of (min_len -1) and (1/2 times min length of two motifs aligned)

    Returns:
        merged_motif_seqs -- nested dict, with the following structure:
            {motif: {seq_idx: idx, atten_region_pos: (start, end)}}
            where seq_idx indicates indices of pos_seqs containing a motif, and
            atten_region_pos indicates where the high attention region is located.

    """
    from Bio import motifs
    from Bio.Seq import Seq

    verbose = False
    if 'verbose' in kwargs:
        verbose = kwargs['verbose']

    if verbose:
        print("*** Begin motif analysis ***")
    pos_seqs = list(pos_seqs)
    neg_seqs = list(neg_seqs)

    if verbose:
        print("* pos_seqs: {}; neg_seqs: {}".format(len(pos_seqs),len(neg_seqs)))

    # One attention-score array per positive sequence.
    assert len(pos_seqs) == len(pos_atten_scores)

    # NOTE(review): max_seq_len is computed but never used below.
    max_seq_len = len(max(pos_seqs, key=len))
    motif_seqs = {}

    ## find the motif regions
    if verbose:
        print("* Finding high attention motif regions")
    for i, score in enumerate(pos_atten_scores):
        # Truncate scores to the actual sequence length (scores may be padded).
        seq_len = len(pos_seqs[i])
        score = score[0:seq_len]

        # handle kwargs
        if 'atten_cond' in kwargs:
            motif_regions = find_high_attention(score, min_len=min_len, cond=kwargs['atten_cond'])
        else:
            motif_regions = find_high_attention(score, min_len=min_len)

        # Group each high-attention subsequence by its literal string; record
        # which sequence it came from and where it sits in that sequence.
        for motif_idx in motif_regions:
            seq = pos_seqs[i][motif_idx[0]:motif_idx[1]]
            if seq not in motif_seqs:
                motif_seqs[seq] = {'seq_idx': [i], 'atten_region_pos':[(motif_idx[0],motif_idx[1])]}
            else:
                motif_seqs[seq]['seq_idx'].append(i)
                motif_seqs[seq]['atten_region_pos'].append((motif_idx[0],motif_idx[1]))


    # filter motifs
    # return_idx is popped so it is not forwarded twice via **kwargs below.
    return_idx = False
    if 'return_idx' in kwargs:
        return_idx = kwargs['return_idx']
        kwargs.pop('return_idx')

    if verbose:
        print("* Filtering motifs by hypergeometric test")
    motifs_to_keep = filter_motifs(pos_seqs,
                                   neg_seqs,
                                   list(motif_seqs.keys()),
                                   cutoff = pval_cutoff,
                                   return_idx=return_idx,
                                   **kwargs)

    # NOTE(review): when return_idx=True, motifs_to_keep holds integer indices
    # while motif_seqs is keyed by motif strings, so this lookup would raise
    # KeyError -- confirm return_idx is only ever used standalone.
    motif_seqs = {k: motif_seqs[k] for k in motifs_to_keep}

    # merge motifs
    if verbose:
        print("* Merging similar motif instances")
    if 'align_cond' in kwargs:
        merged_motif_seqs = merge_motifs(motif_seqs, min_len=min_len,
                                         align_all_ties = align_all_ties,
                                         cond=kwargs['align_cond'])
    else:
        merged_motif_seqs = merge_motifs(motif_seqs, min_len=min_len,
                                         align_all_ties = align_all_ties)

    # make fixed-length window sequences
    if verbose:
        print("* Making fixed_length window = {}".format(window_size))
    merged_motif_seqs = make_window(merged_motif_seqs, pos_seqs, window_size=window_size)

    # remove motifs with only few instances
    if verbose:
        print("* Removing motifs with less than {} instances".format(min_n_motif))
    merged_motif_seqs = {k: coords for k, coords in merged_motif_seqs.items() if len(coords['seq_idx']) >= min_n_motif}

    if save_file_dir is not None:
        if verbose:
            print("* Saving outputs to directory")
        os.makedirs(save_file_dir, exist_ok=True)
        for motif, instances in merged_motif_seqs.items():
            # saving to files
            with open(save_file_dir+'/motif_{}_{}.txt'.format(motif, len(instances['seq_idx'])), 'w') as f:
                for seq in instances['seqs']:
                    f.write(seq+'\n')
            # make weblogo
            seqs = [Seq(v) for i,v in enumerate(instances['seqs'])]
            m = motifs.create(seqs)
            m.weblogo(save_file_dir+"/motif_{}_{}_weblogo.png".format(motif, len(instances['seq_idx'])), format='png_print',
                      show_fineprint=False, show_ends=False, color_scheme='color_classic')

    return merged_motif_seqs
save2cache.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import logging
4
+ import os
5
+ import pickle
6
+ import random
7
+ import re
8
+ import shutil
9
+ from typing import Dict, List, Tuple
10
+ from copy import deepcopy
11
+ from multiprocessing import Pool
12
+
13
+ import numpy as np
14
+ import torch
15
+ from torch.nn.utils.rnn import pad_sequence
16
+ from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from tqdm import tqdm, trange
19
+ import itertools
20
+
21
+ from transformers import (
22
+ WEIGHTS_NAME,
23
+ AdamW,
24
+ BertConfig,
25
+ BertForMaskedLM,
26
+ BertTokenizer,
27
+ DNATokenizer,
28
+ #myTokenizer,
29
+ #MotifTokenizer,
30
+ CamembertConfig,
31
+ CamembertForMaskedLM,
32
+ CamembertTokenizer,
33
+ DistilBertConfig,
34
+ DistilBertForMaskedLM,
35
+ DistilBertTokenizer,
36
+ GPT2Config,
37
+ GPT2LMHeadModel,
38
+ GPT2Tokenizer,
39
+ OpenAIGPTConfig,
40
+ OpenAIGPTLMHeadModel,
41
+ OpenAIGPTTokenizer,
42
+ PreTrainedModel,
43
+ PreTrainedTokenizer,
44
+ RobertaConfig,
45
+ RobertaForMaskedLM,
46
+ RobertaTokenizer,
47
+ get_linear_schedule_with_warmup,
48
+ )
49
+
50
+
51
+ try:
52
+ from torch.utils.tensorboard import SummaryWriter
53
+ except ImportError:
54
+ from tensorboardX import SummaryWriter
55
+
56
+
57
+ MODEL_CLASSES = {
58
+ "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
59
+ "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
60
+ "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
61
+ "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
62
+ "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
63
+ "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
64
+ "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
65
+ #"myBert": (BertConfig, BertForMaskedLM, myTokenizer),
66
+ #"motifBert": (BertConfig, BertForMaskedLM, MotifTokenizer)
67
+ }
68
+
69
def convert_line_to_example(tokenizer, lines, max_length, add_special_tokens=True):
    """Tokenize a batch of text lines and return their input-id lists."""
    encoded = tokenizer.batch_encode_plus(
        lines, add_special_tokens=add_special_tokens, max_length=max_length
    )
    return encoded["input_ids"]
72
+
73
class LineByLineTextDataset(Dataset):
    """Dataset of tokenized examples, one per non-empty line of a text file.

    Tokenization is optionally parallelised across ``args.n_process`` worker
    processes; the resulting input-id lists are pickled to a cache file next
    to the source file.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        assert os.path.isfile(file_path)
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        directory, filename = os.path.split(file_path)
        # Cache filename encodes model type and block size so different
        # configurations do not collide.
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
        )

        # NOTE(review): print used %-style placeholders as if it were a
        # logging call; the literal "%s" is printed. Preserved as-is.
        print("Creating features from dataset file at %s", file_path)

        # Keep only non-empty, non-whitespace lines.
        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        if args.n_process == 1:
            self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]
        else:
            # Split the lines into n_process contiguous slices and tokenize
            # them in parallel worker processes.
            n_proc = args.n_process
            p = Pool(n_proc)
            indexes = [0]
            len_slice = int(len(lines)/n_proc)
            for i in range(1, n_proc+1):
                if i != n_proc:
                    indexes.append(len_slice*(i))
                else:
                    # Last slice absorbs the division remainder.
                    indexes.append(len(lines))
            results = []
            for i in range(n_proc):
                results.append(p.apply_async(convert_line_to_example,[tokenizer, lines[indexes[i]:indexes[i+1]], block_size,]))
                print(str(i) + " start")
            p.close()
            p.join()

            # Re-assemble worker outputs in slice order.
            self.examples = []
            for result in results:
                ids = result.get()
                self.examples.extend(ids)
        # Persist the tokenized examples so later runs can reuse them.
        # NOTE(review): indentation in the rendered diff is ambiguous -- this
        # save is assumed to run for both the single- and multi-process
        # paths, matching the script's "save to cache" purpose; confirm
        # against the original file.
        print("Saving features into cached file %s", cached_features_file)
        with open(cached_features_file, "wb") as handle:
            pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        # Number of tokenized examples.
        return len(self.examples)

    def __getitem__(self, i):
        # One example as a LongTensor of token ids.
        return torch.tensor(self.examples[i], dtype=torch.long)
+
122
+
123
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """
    Build the dataset over the train or eval data file.

    Arguments:
        args -- parsed argparse namespace; reads eval_data_file,
                train_data_file, line_by_line and block_size
        tokenizer -- tokenizer used to encode the text

    Keyword arguments:
        evaluate -- bool, load the eval file instead of the train file
                    (default False)

    Returns:
        a LineByLineTextDataset over the selected file

    Raises:
        NotImplementedError -- if --line_by_line is not set: the chunked
            TextDataset is not defined anywhere in this script.
    """
    file_path = args.eval_data_file if evaluate else args.train_data_file
    print(file_path)
    if args.line_by_line:
        return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
    # BUGFIX: the original fell through to `TextDataset`, a name never defined
    # in this module, and crashed with a bare NameError. Fail with an
    # explicit, actionable error instead.
    raise NotImplementedError(
        "TextDataset is not available in this script; rerun with --line_by_line"
    )
+
131
+
132
def main():
    """Tokenize and cache the eval and/or train data files.

    Relies on the module-level ``args`` and ``tokenizer`` created in the
    ``__main__`` block below; it is not callable as a standalone function.
    """
    # Build (and thereby cache) the eval dataset first, if one was given.
    if args.eval_data_file:
        eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
        print('done')

    # Then the training dataset (--train_data_file is a required argument).
    if args.train_data_file:
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
+
141
+
142
if __name__ == '__main__':
    # Parse CLI arguments, build the tokenizer, then run main() to tokenize
    # and cache the requested data files.

    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )

    # Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
    parser.add_argument(
        "--line_by_line",
        action="store_true",
        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
    )

    parser.add_argument(
        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )

    parser.add_argument(
        "--config_name",
        default=None,
        type=str,
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
    )

    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    # NOTE(review): the help text below is a copy-paste of --block_size's and
    # does not describe --specialpath; left unchanged (runtime string).
    parser.add_argument(
        "--specialpath",
        type=str,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )


    parser.add_argument("--n_process", type=int, default=1, help="")
    args = parser.parse_args()

    # Resolve (config, model, tokenizer) classes for the requested model type.
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if args.config_name:
        config = config_class.from_pretrained(args.config_name, cache_dir=None)
    else:
        config = config_class()

    # A pretrained tokenizer is mandatory: vocabulary creation is out of scope.
    if args.tokenizer_name:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=None)
    else:
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )

    # Clamp block_size to the tokenizer's maximum input length.
    if args.block_size <= 0:
        args.block_size = tokenizer.max_len
        # Our input block size will be the max possible for the model
    else:
        args.block_size = min(args.block_size, tokenizer.max_len)

    main()
224
+
setup.cfg ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [isort]
2
+ ensure_newline_before_comments = True
3
+ force_grid_wrap = 0
4
+ include_trailing_comma = True
5
+ known_first_party = transformers
6
+ known_third_party =
7
+ absl
8
+ fairseq
9
+ fastprogress
10
+ git
11
+ h5py
12
+ MeCab
13
+ nltk
14
+ numpy
15
+ packaging
16
+ PIL
17
+ psutil
18
+ pytorch_lightning
19
+ seqeval
20
+ sklearn
21
+ tensorboardX
22
+ tensorflow
23
+ tensorflow_datasets
24
+ torch
25
+ torchtext
26
+ torchvision
27
+ torch_xla
28
+
29
+ line_length = 119
30
+ lines_after_imports = 2
31
+ multi_line_output = 3
32
+ use_parentheses = True
33
+
34
+ [flake8]
35
+ ignore = E203, E501, W503
36
+ max-line-length = 119
setup.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3
+
4
+ To create the package for pypi.
5
+
6
+ 1. Change the version in __init__.py, setup.py as well as docs/source/conf.py.
7
+
8
+ 2. Commit these changes with the message: "Release: VERSION"
9
+
10
+ 3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
11
+ Push the tag to git: git push --tags origin master
12
+
13
+ 4. Build both the sources and the wheel. Do not change anything in setup.py between
14
+ creating the wheel and the source distribution (obviously).
15
+
16
+ For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
17
+ (this will build a wheel for the python version you use to build it).
18
+
19
+ For the sources, run: "python setup.py sdist"
20
+ You should now have a /dist directory with both .whl and .tar.gz source versions.
21
+
22
+ 5. Check that everything looks correct by uploading the package to the pypi test server:
23
+
24
+ twine upload dist/* -r pypitest
25
+ (pypi suggest using twine as other methods upload files via plaintext.)
26
+ You may have to specify the repository url, use the following command then:
27
+ twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
28
+
29
+ Check that you can install it in a virtualenv by running:
30
+ pip install -i https://testpypi.python.org/pypi transformers
31
+
32
+ 6. Upload the final version to actual pypi:
33
+ twine upload dist/* -r pypi
34
+
35
+ 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
36
+
37
+ 8. Update the documentation commit in .circleci/deploy.sh for the accurate documentation to be displayed
38
+
39
+ 9. Update README.md to redirect to correct documentation.
40
+ """
41
+
42
+ import shutil
43
+ from pathlib import Path
44
+
45
+ from setuptools import find_packages, setup
46
+
47
+
48
+ # Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
49
+ stale_egg_info = Path(__file__).parent / "transformers.egg-info"
50
+ if stale_egg_info.exists():
51
+ print(
52
+ (
53
+ "Warning: {} exists.\n\n"
54
+ "If you recently updated transformers to 3.0 or later, this is expected,\n"
55
+ "but it may prevent transformers from installing in editable mode.\n\n"
56
+ "This directory is automatically generated by Python's packaging tools.\n"
57
+ "I will remove it now.\n\n"
58
+ "See https://github.com/pypa/pip/issues/5466 for details.\n"
59
+ ).format(stale_egg_info)
60
+ )
61
+ shutil.rmtree(stale_egg_info)
62
+
63
+
64
+ extras = {}
65
+
66
+ extras["mecab"] = ["mecab-python3"]
67
+ extras["sklearn"] = ["scikit-learn"]
68
+ extras["tf"] = ["tensorflow"]
69
+ extras["tf-cpu"] = ["tensorflow-cpu"]
70
+ extras["torch"] = ["torch"]
71
+
72
+ extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
73
+ extras["all"] = extras["serving"] + ["tensorflow", "torch"]
74
+
75
+ extras["testing"] = ["pytest", "pytest-xdist"]
76
+ extras["quality"] = ["black", "isort", "flake8"]
77
+ extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"]
78
+ extras["dev"] = extras["testing"] + extras["quality"] + ["mecab-python3", "scikit-learn", "tensorflow", "torch"]
79
+
80
+ setup(
81
+ name="transformers",
82
+ version="2.5.0",
83
+ author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
84
+ author_email="thomas@huggingface.co",
85
+ description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
86
+ long_description=open("README.md", "r", encoding="utf-8").read(),
87
+ long_description_content_type="text/markdown",
88
+ keywords="NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU",
89
+ license="Apache",
90
+ url="https://github.com/huggingface/transformers",
91
+ package_dir={"": "src"},
92
+ packages=find_packages("src"),
93
+ install_requires=[
94
+ "numpy",
95
+ "tokenizers == 0.5.0",
96
+ # accessing files from S3 directly
97
+ "boto3",
98
+ # filesystem locks e.g. to prevent parallel downloads
99
+ "filelock",
100
+ # for downloading models over HTTPS
101
+ "requests",
102
+ # progress bars in model download and training scripts
103
+ "tqdm >= 4.27",
104
+ # for OpenAI GPT
105
+ "regex != 2019.12.17",
106
+ # for XLNet
107
+ "sentencepiece",
108
+ # for XLM
109
+ "sacremoses",
110
+ ],
111
+ extras_require=extras,
112
+ scripts=["transformers-cli"],
113
+ python_requires=">=3.5.0",
114
+ classifiers=[
115
+ "Development Status :: 5 - Production/Stable",
116
+ "Intended Audience :: Developers",
117
+ "Intended Audience :: Education",
118
+ "Intended Audience :: Science/Research",
119
+ "License :: OSI Approved :: Apache Software License",
120
+ "Operating System :: OS Independent",
121
+ "Programming Language :: Python :: 3",
122
+ "Programming Language :: Python :: 3.5",
123
+ "Programming Language :: Python :: 3.6",
124
+ "Programming Language :: Python :: 3.7",
125
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
126
+ ],
127
+ )
src/transformers/__init__.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all.
4
+
5
+ __version__ = "2.5.0"
6
+
7
+ # Work around to update TensorFlow's absl.logging threshold which alters the
8
+ # default Python logging output behavior when present.
9
+ # see: https://github.com/abseil/abseil-py/issues/99
10
+ # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
11
+ try:
12
+ import absl.logging
13
+ except ImportError:
14
+ pass
15
+ else:
16
+ absl.logging.set_verbosity("info")
17
+ absl.logging.set_stderrthreshold("info")
18
+ absl.logging._warn_preinit_stderr = False
19
+
20
+ import logging
21
+
22
+ from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
23
+ from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
24
+ from .configuration_bart import BartConfig
25
+ from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
26
+ from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
27
+ from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
28
+ from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
29
+ from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
30
+ from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
31
+ from .configuration_mmbt import MMBTConfig
32
+ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
33
+ from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
34
+ from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
35
+ from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
36
+
37
+ # Configurations
38
+ from .configuration_utils import PretrainedConfig
39
+ from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
40
+ from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
41
+ from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
42
+ from .data import (
43
+ DataProcessor,
44
+ InputExample,
45
+ InputFeatures,
46
+ SingleSentenceClassificationProcessor,
47
+ SquadExample,
48
+ SquadFeatures,
49
+ SquadV1Processor,
50
+ SquadV2Processor,
51
+ glue_convert_examples_to_features,
52
+ glue_output_modes,
53
+ glue_processors,
54
+ glue_tasks_num_labels,
55
+ is_sklearn_available,
56
+ squad_convert_examples_to_features,
57
+ xnli_output_modes,
58
+ xnli_processors,
59
+ xnli_tasks_num_labels,
60
+ )
61
+
62
+ # Files and general utilities
63
+ from .file_utils import (
64
+ CONFIG_NAME,
65
+ MODEL_CARD_NAME,
66
+ PYTORCH_PRETRAINED_BERT_CACHE,
67
+ PYTORCH_TRANSFORMERS_CACHE,
68
+ TF2_WEIGHTS_NAME,
69
+ TF_WEIGHTS_NAME,
70
+ TRANSFORMERS_CACHE,
71
+ WEIGHTS_NAME,
72
+ add_end_docstrings,
73
+ add_start_docstrings,
74
+ cached_path,
75
+ is_tf_available,
76
+ is_torch_available,
77
+ )
78
+
79
+ # Model Cards
80
+ from .modelcard import ModelCard
81
+
82
+ # TF 2.0 <=> PyTorch conversion utilities
83
+ from .modeling_tf_pytorch_utils import (
84
+ convert_tf_weight_name_to_pt_weight_name,
85
+ load_pytorch_checkpoint_in_tf2_model,
86
+ load_pytorch_model_in_tf2_model,
87
+ load_pytorch_weights_in_tf2_model,
88
+ load_tf2_checkpoint_in_pytorch_model,
89
+ load_tf2_model_in_pytorch_model,
90
+ load_tf2_weights_in_pytorch_model,
91
+ )
92
+
93
+ # Pipelines
94
+ from .pipelines import (
95
+ CsvPipelineDataFormat,
96
+ FeatureExtractionPipeline,
97
+ FillMaskPipeline,
98
+ JsonPipelineDataFormat,
99
+ NerPipeline,
100
+ PipedPipelineDataFormat,
101
+ Pipeline,
102
+ PipelineDataFormat,
103
+ QuestionAnsweringPipeline,
104
+ TextClassificationPipeline,
105
+ TokenClassificationPipeline,
106
+ pipeline,
107
+ )
108
+ from .tokenization_albert import AlbertTokenizer
109
+ from .tokenization_auto import AutoTokenizer
110
+ from .tokenization_bart import BartTokenizer
111
+ from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
112
+ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
113
+ from .tokenization_camembert import CamembertTokenizer
114
+ from .tokenization_ctrl import CTRLTokenizer
115
+ from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
116
+ from .tokenization_flaubert import FlaubertTokenizer
117
+ from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
118
+ from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
119
+ from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
120
+ from .tokenization_t5 import T5Tokenizer
121
+ from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
122
+ from .tokenization_dna import DNATokenizer
123
+
124
+ # Tokenizers
125
+ from .tokenization_utils import PreTrainedTokenizer
126
+ from .tokenization_xlm import XLMTokenizer
127
+ from .tokenization_xlm_roberta import XLMRobertaTokenizer
128
+ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
129
+
130
+
131
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
132
+
133
+
134
+ if is_sklearn_available():
135
+ from .data import glue_compute_metrics, xnli_compute_metrics
136
+
137
+
138
+ # Modeling
139
+ if is_torch_available():
140
+ from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
141
+ from .modeling_auto import (
142
+ AutoModel,
143
+ AutoModelForPreTraining,
144
+ AutoModelForSequenceClassification,
145
+ AutoModelForQuestionAnswering,
146
+ AutoModelWithLMHead,
147
+ AutoModelForTokenClassification,
148
+ ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
149
+ )
150
+
151
+ from .modeling_bert import (
152
+ BertPreTrainedModel,
153
+ BertModel,
154
+ BertForPreTraining,
155
+ BertForMaskedLM,
156
+ BertForNextSentencePrediction,
157
+ BertForSequenceClassification,
158
+ BertForLongSequenceClassification,
159
+ BertForLongSequenceClassificationCat,
160
+ BertForMultipleChoice,
161
+ BertForTokenClassification,
162
+ BertForQuestionAnswering,
163
+ load_tf_weights_in_bert,
164
+ BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
165
+ )
166
+ from .modeling_openai import (
167
+ OpenAIGPTPreTrainedModel,
168
+ OpenAIGPTModel,
169
+ OpenAIGPTLMHeadModel,
170
+ OpenAIGPTDoubleHeadsModel,
171
+ load_tf_weights_in_openai_gpt,
172
+ OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
173
+ )
174
+ from .modeling_transfo_xl import (
175
+ TransfoXLPreTrainedModel,
176
+ TransfoXLModel,
177
+ TransfoXLLMHeadModel,
178
+ AdaptiveEmbedding,
179
+ load_tf_weights_in_transfo_xl,
180
+ TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
181
+ )
182
+ from .modeling_gpt2 import (
183
+ GPT2PreTrainedModel,
184
+ GPT2Model,
185
+ GPT2LMHeadModel,
186
+ GPT2DoubleHeadsModel,
187
+ load_tf_weights_in_gpt2,
188
+ GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
189
+ )
190
+ from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
191
+ from .modeling_xlnet import (
192
+ XLNetPreTrainedModel,
193
+ XLNetModel,
194
+ XLNetLMHeadModel,
195
+ XLNetForSequenceClassification,
196
+ XLNetForTokenClassification,
197
+ XLNetForMultipleChoice,
198
+ XLNetForQuestionAnsweringSimple,
199
+ XLNetForQuestionAnswering,
200
+ load_tf_weights_in_xlnet,
201
+ XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
202
+ )
203
+ from .modeling_xlm import (
204
+ XLMPreTrainedModel,
205
+ XLMModel,
206
+ XLMWithLMHeadModel,
207
+ XLMForSequenceClassification,
208
+ XLMForQuestionAnswering,
209
+ XLMForQuestionAnsweringSimple,
210
+ XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
211
+ )
212
+ from .modeling_bart import BartForSequenceClassification, BartModel, BartForMaskedLM
213
+ from .modeling_roberta import (
214
+ RobertaForMaskedLM,
215
+ RobertaModel,
216
+ RobertaForSequenceClassification,
217
+ RobertaForMultipleChoice,
218
+ RobertaForTokenClassification,
219
+ RobertaForQuestionAnswering,
220
+ ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
221
+ )
222
+ from .modeling_camembert import (
223
+ CamembertForMaskedLM,
224
+ CamembertModel,
225
+ CamembertForSequenceClassification,
226
+ CamembertForTokenClassification,
227
+ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
228
+ )
229
+ from .modeling_distilbert import (
230
+ DistilBertPreTrainedModel,
231
+ DistilBertForMaskedLM,
232
+ DistilBertModel,
233
+ DistilBertForSequenceClassification,
234
+ DistilBertForQuestionAnswering,
235
+ DistilBertForTokenClassification,
236
+ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
237
+ )
238
+ from .modeling_camembert import (
239
+ CamembertForMaskedLM,
240
+ CamembertModel,
241
+ CamembertForSequenceClassification,
242
+ CamembertForMultipleChoice,
243
+ CamembertForTokenClassification,
244
+ CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
245
+ )
246
+ from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
247
+ from .modeling_t5 import (
248
+ T5PreTrainedModel,
249
+ T5Model,
250
+ T5WithLMHeadModel,
251
+ load_tf_weights_in_t5,
252
+ T5_PRETRAINED_MODEL_ARCHIVE_MAP,
253
+ )
254
+ from .modeling_albert import (
255
+ AlbertPreTrainedModel,
256
+ AlbertModel,
257
+ AlbertForMaskedLM,
258
+ AlbertForSequenceClassification,
259
+ AlbertForQuestionAnswering,
260
+ load_tf_weights_in_albert,
261
+ ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
262
+ )
263
+ from .modeling_xlm_roberta import (
264
+ XLMRobertaForMaskedLM,
265
+ XLMRobertaModel,
266
+ XLMRobertaForMultipleChoice,
267
+ XLMRobertaForSequenceClassification,
268
+ XLMRobertaForTokenClassification,
269
+ XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
270
+ )
271
+ from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
272
+
273
+ from .modeling_flaubert import (
274
+ FlaubertModel,
275
+ FlaubertWithLMHeadModel,
276
+ FlaubertForSequenceClassification,
277
+ FlaubertForQuestionAnswering,
278
+ FlaubertForQuestionAnsweringSimple,
279
+ FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
280
+ )
281
+
282
+ # Optimization
283
+ from .optimization import (
284
+ AdamW,
285
+ get_constant_schedule,
286
+ get_constant_schedule_with_warmup,
287
+ get_cosine_schedule_with_warmup,
288
+ get_cosine_with_hard_restarts_schedule_with_warmup,
289
+ get_linear_schedule_with_warmup,
290
+ )
291
+
292
+
293
+ # TensorFlow
294
+ if is_tf_available():
295
+ from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
296
+ from .modeling_tf_auto import (
297
+ TFAutoModel,
298
+ TFAutoModelForPreTraining,
299
+ TFAutoModelForSequenceClassification,
300
+ TFAutoModelForQuestionAnswering,
301
+ TFAutoModelWithLMHead,
302
+ TFAutoModelForTokenClassification,
303
+ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
304
+ )
305
+
306
+ from .modeling_tf_bert import (
307
+ TFBertPreTrainedModel,
308
+ TFBertMainLayer,
309
+ TFBertEmbeddings,
310
+ TFBertModel,
311
+ TFBertForPreTraining,
312
+ TFBertForMaskedLM,
313
+ TFBertForNextSentencePrediction,
314
+ TFBertForSequenceClassification,
315
+ TFBertForMultipleChoice,
316
+ TFBertForTokenClassification,
317
+ TFBertForQuestionAnswering,
318
+ TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
319
+ )
320
+
321
+ from .modeling_tf_gpt2 import (
322
+ TFGPT2PreTrainedModel,
323
+ TFGPT2MainLayer,
324
+ TFGPT2Model,
325
+ TFGPT2LMHeadModel,
326
+ TFGPT2DoubleHeadsModel,
327
+ TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
328
+ )
329
+
330
+ from .modeling_tf_openai import (
331
+ TFOpenAIGPTPreTrainedModel,
332
+ TFOpenAIGPTMainLayer,
333
+ TFOpenAIGPTModel,
334
+ TFOpenAIGPTLMHeadModel,
335
+ TFOpenAIGPTDoubleHeadsModel,
336
+ TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
337
+ )
338
+
339
+ from .modeling_tf_transfo_xl import (
340
+ TFTransfoXLPreTrainedModel,
341
+ TFTransfoXLMainLayer,
342
+ TFTransfoXLModel,
343
+ TFTransfoXLLMHeadModel,
344
+ TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
345
+ )
346
+
347
+ from .modeling_tf_xlnet import (
348
+ TFXLNetPreTrainedModel,
349
+ TFXLNetMainLayer,
350
+ TFXLNetModel,
351
+ TFXLNetLMHeadModel,
352
+ TFXLNetForSequenceClassification,
353
+ TFXLNetForTokenClassification,
354
+ TFXLNetForQuestionAnsweringSimple,
355
+ TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
356
+ )
357
+
358
+ from .modeling_tf_xlm import (
359
+ TFXLMPreTrainedModel,
360
+ TFXLMMainLayer,
361
+ TFXLMModel,
362
+ TFXLMWithLMHeadModel,
363
+ TFXLMForSequenceClassification,
364
+ TFXLMForQuestionAnsweringSimple,
365
+ TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
366
+ )
367
+
368
+ from .modeling_tf_xlm_roberta import (
369
+ TFXLMRobertaForMaskedLM,
370
+ TFXLMRobertaModel,
371
+ TFXLMRobertaForSequenceClassification,
372
+ TFXLMRobertaForTokenClassification,
373
+ TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
374
+ )
375
+
376
+ from .modeling_tf_roberta import (
377
+ TFRobertaPreTrainedModel,
378
+ TFRobertaMainLayer,
379
+ TFRobertaModel,
380
+ TFRobertaForMaskedLM,
381
+ TFRobertaForSequenceClassification,
382
+ TFRobertaForTokenClassification,
383
+ TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
384
+ )
385
+
386
+ from .modeling_tf_camembert import (
387
+ TFCamembertModel,
388
+ TFCamembertForMaskedLM,
389
+ TFCamembertForSequenceClassification,
390
+ TFCamembertForTokenClassification,
391
+ TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
392
+ )
393
+
394
+ from .modeling_tf_distilbert import (
395
+ TFDistilBertPreTrainedModel,
396
+ TFDistilBertMainLayer,
397
+ TFDistilBertModel,
398
+ TFDistilBertForMaskedLM,
399
+ TFDistilBertForSequenceClassification,
400
+ TFDistilBertForTokenClassification,
401
+ TFDistilBertForQuestionAnswering,
402
+ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
403
+ )
404
+
405
+ from .modeling_tf_ctrl import (
406
+ TFCTRLPreTrainedModel,
407
+ TFCTRLModel,
408
+ TFCTRLLMHeadModel,
409
+ TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
410
+ )
411
+
412
+ from .modeling_tf_albert import (
413
+ TFAlbertPreTrainedModel,
414
+ TFAlbertModel,
415
+ TFAlbertForMaskedLM,
416
+ TFAlbertForSequenceClassification,
417
+ TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
418
+ )
419
+
420
+ from .modeling_tf_t5 import (
421
+ TFT5PreTrainedModel,
422
+ TFT5Model,
423
+ TFT5WithLMHeadModel,
424
+ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
425
+ )
426
+
427
+ # Optimization
428
+ from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
429
+
430
+
431
+ if not is_tf_available() and not is_torch_available():
432
+ logger.warning(
433
+ "Neither PyTorch nor TensorFlow >= 2.0 have been found."
434
+ "Models won't be available and only tokenizers, configuration"
435
+ "and file/data utilities can be used."
436
+ )
src/transformers/activations.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
+ def swish(x):
8
+ return x * torch.sigmoid(x)
9
+
10
+
11
+ def _gelu_python(x):
12
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
13
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
14
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
15
+ This is now written in C in torch.nn.functional
16
+ Also see https://arxiv.org/abs/1606.08415
17
+ """
18
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
19
+
20
+
21
+ gelu = getattr(F, "gelu", _gelu_python)
22
+
23
+
24
+ def gelu_new(x):
25
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
26
+ Also see https://arxiv.org/abs/1606.08415
27
+ """
28
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
29
+
30
+
31
+ ACT2FN = {
32
+ "relu": F.relu,
33
+ "swish": swish,
34
+ "gelu": gelu,
35
+ "tanh": F.tanh,
36
+ "gelu_new": gelu_new,
37
+ }
38
+
39
+
40
+ def get_activation(activation_string):
41
+ if activation_string in ACT2FN:
42
+ return ACT2FN[activation_string]
43
+ else:
44
+ raise KeyError(
45
+ "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
46
+ activation_string, list(ACT2FN.keys())
47
+ )
48
+ )
src/transformers/commands/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from argparse import ArgumentParser
3
+
4
+
5
class BaseTransformersCLICommand(ABC):
    """Abstract interface every ``transformers-cli`` subcommand implements.

    Subclasses register their argparse sub-parser via ``register_subcommand``
    and perform their work in ``run``.
    """

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach this command's sub-parser and arguments to *parser*."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
src/transformers/commands/convert.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser, Namespace
2
+ from logging import getLogger
3
+
4
+ from transformers.commands import BaseTransformersCLICommand
5
+
6
+
7
def convert_command_factory(args: Namespace):
    """
    Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
    :return: ConvertCommand
    """
    # NOTE: the original docstring claimed ServeCommand; this returns ConvertCommand.
    return ConvertCommand(
        args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
    )
15
+
16
+
17
class ConvertCommand(BaseTransformersCLICommand):
    """CLI command converting an original (usually TensorFlow) checkpoint into a
    Transformers PyTorch checkpoint, dispatching on ``--model_type``."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser(
            "convert",
            help="CLI tool to run convert model from original "
            "author checkpoints to Transformers PyTorch checkpoints.",
        )
        train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
        train_parser.add_argument(
            "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
        )
        # Typo fix: "savd" -> "saved" in the user-facing help text.
        train_parser.add_argument(
            "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output."
        )
        train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
        train_parser.add_argument(
            "--finetuning_task_name",
            type=str,
            default=None,
            help="Optional fine-tuning task name if the TF model was a finetuned model.",
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(
        self,
        model_type: str,
        tf_checkpoint: str,
        pytorch_dump_output: str,
        config: str,
        finetuning_task_name: str,
        *args
    ):
        self._logger = getLogger("transformers-cli/converting")

        self._logger.info("Loading model {}".format(model_type))
        self._model_type = model_type
        self._tf_checkpoint = tf_checkpoint
        self._pytorch_dump_output = pytorch_dump_output
        self._config = config
        self._finetuning_task_name = finetuning_task_name

    def run(self):
        """Dispatch to the model-type specific conversion function.

        :raises ImportError: when the conversion module needs TensorFlow and it
            is not installed.
        :raises ValueError: when ``--model_type`` is not a supported type.
        """
        # Single copy of the error message that was previously duplicated
        # verbatim in every TensorFlow-dependent branch.
        tf_required_msg = (
            "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
            "In that case, it requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        if self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(tf_required_msg)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
                convert_openai_checkpoint_to_pytorch,
            )

            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(tf_required_msg)

            # A raw TF checkpoint path contains "ckpt"; anything else is
            # treated as a dataset file.
            if "ckpt" in self._tf_checkpoint.lower():
                TF_CHECKPOINT = self._tf_checkpoint
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = self._tf_checkpoint
                TF_CHECKPOINT = ""
            convert_transfo_xl_checkpoint_to_pytorch(
                TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
            )
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(tf_required_msg)

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(tf_required_msg)

            convert_xlnet_checkpoint_to_pytorch(
                self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
            )
        elif self._model_type == "xlm":
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch,
            )

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        else:
            raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]")
src/transformers/commands/download.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+
3
+ from transformers.commands import BaseTransformersCLICommand
4
+
5
+
6
def download_command_factory(args):
    """Build a ``DownloadCommand`` from the parsed CLI arguments."""
    return DownloadCommand(args.model, args.cache_dir, args.force)
8
+
9
+
10
class DownloadCommand(BaseTransformersCLICommand):
    """``transformers-cli download``: pre-fetch a model and its tokenizer into the local cache."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``download`` subcommand and its arguments on *parser*."""
        download_parser = parser.add_parser("download")
        download_parser.add_argument(
            "--cache-dir", type=str, default=None, help="Path to location to store the models"
        )
        download_parser.add_argument(
            "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
        )
        download_parser.add_argument("model", type=str, help="Name of the model to download")
        download_parser.set_defaults(func=download_command_factory)

    def __init__(self, model: str, cache: str, force: bool):
        self._model = model
        self._cache = cache
        self._force = force

    def run(self):
        """Download both the model weights and the tokenizer files.

        Imported lazily so the CLI stays importable without a backend.
        """
        from transformers import AutoModel, AutoTokenizer

        AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
        AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
src/transformers/commands/env.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ from argparse import ArgumentParser
3
+
4
+ from transformers import __version__ as version
5
+ from transformers import is_tf_available, is_torch_available
6
+ from transformers.commands import BaseTransformersCLICommand
7
+
8
+
9
def info_command_factory(_):
    """Build an ``EnvironmentCommand``; the parsed args are ignored."""
    return EnvironmentCommand()
11
+
12
+
13
class EnvironmentCommand(BaseTransformersCLICommand):
    """``transformers-cli env``: print platform/version info for bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``env`` subcommand on *parser*."""
        env_parser = parser.add_parser("env")
        env_parser.set_defaults(func=info_command_factory)

    def run(self):
        """Collect environment info, print it in copy-pasteable form and return it."""
        pt_version, pt_cuda_available = "not installed", "NA"
        if is_torch_available():
            import torch

            pt_version = torch.__version__
            pt_cuda_available = torch.cuda.is_available()

        tf_version, tf_cuda_available = "not installed", "NA"
        if is_tf_available():
            import tensorflow as tf

            tf_version = tf.__version__
            try:
                # deprecated in v2.1
                tf_cuda_available = tf.test.is_gpu_available()
            except AttributeError:
                # returns list of devices, convert to bool
                tf_cuda_available = bool(tf.config.list_physical_devices("GPU"))

        info = {
            "`transformers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": "{} ({})".format(pt_version, pt_cuda_available),
            "Tensorflow version (GPU?)": "{} ({})".format(tf_version, tf_cuda_available),
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        """Render *d* as a markdown bullet list, one ``- key: value`` per line."""
        bullet_lines = ["- {}: {}".format(prop, val) for prop, val in d.items()]
        return "\n".join(bullet_lines) + "\n"
src/transformers/commands/run.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from argparse import ArgumentParser
3
+
4
+ from transformers.commands import BaseTransformersCLICommand
5
+ from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
6
+
7
+
8
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
9
+
10
+
11
def try_infer_format_from_ext(path: str):
    """Infer the pipeline data format from *path*'s extension.

    An empty/missing path means data arrives on stdin ("pipe"). Raises when the
    extension matches none of ``PipelineDataFormat.SUPPORTED_FORMATS``.
    """
    if not path:
        return "pipe"

    matched = next((ext for ext in PipelineDataFormat.SUPPORTED_FORMATS if path.endswith(ext)), None)
    if matched is not None:
        return matched

    raise Exception(
        "Unable to determine file format from file extension {}. "
        "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
    )
23
+
24
+
25
def run_command_factory(args):
    """Build a ``RunCommand``: instantiate the pipeline and its data reader from CLI args."""
    nlp = pipeline(
        task=args.task,
        model=args.model or None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    # "infer" means: guess the format from the input file's extension.
    data_format = args.format if args.format != "infer" else try_infer_format_from_ext(args.input)
    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        column=args.column or nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
42
+
43
+
44
class RunCommand(BaseTransformersCLICommand):
    """``transformers-cli run``: feed every record of a data reader through a pipeline and save the results."""

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``run`` subcommand and its arguments on *parser*."""
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """Run the pipeline over every entry of the reader, then persist the outputs."""
        outputs = []
        for entry in self._reader:
            # Multi-column readers pass their fields as keyword arguments.
            if self._reader.is_multi_columns:
                result = self._nlp(**entry)
            else:
                result = self._nlp(entry)
            if isinstance(result, dict):
                outputs.append(result)
            else:
                outputs += result

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
        else:
            self._reader.save(outputs)
src/transformers/commands/serving.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline


# The serving stack (uvicorn / FastAPI / pydantic / starlette) is an optional extra.
# When it is missing we install lightweight stand-ins so this module can still be
# imported; ServeCommand.__init__ raises a helpful RuntimeError before anything is used.
try:
    from uvicorn import run
    from fastapi import FastAPI, HTTPException, Body
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # Fallback base class keeps the result-model class definitions below valid.
    BaseModel = object

    def Body(*x, **y):
        # Stand-in for fastapi.Body used in method default arguments; never called for real.
        pass

    _serve_dependencies_installed = False


logger = logging.getLogger("transformers-cli/serving")
28
+
29
+
30
def serve_command_factory(args: Namespace):
    """
    Build a :class:`ServeCommand` from parsed command line arguments.

    Instantiates the inference pipeline first, then wraps it with the
    host/port/workers serving configuration.
    :return: ServeCommand
    """
    serving_pipeline = pipeline(
        task=args.task,
        model=args.model or None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(serving_pipeline, args.host, args.port, args.workers)
43
+
44
+
45
class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    # Raw model configuration as a plain dict (vars(model.config)).
    infos: dict
51
+
52
+
53
class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    # Token strings produced by the tokenizer.
    tokens: List[str]
    # Integer ids; only populated when the request asked for them (return_ids=True).
    tokens_ids: Optional[List[int]]
60
+
61
+
62
class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    # Decoded, human-readable text.
    text: str
68
+
69
+
70
class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    # Pipeline output; concrete shape depends on the underlying task.
    output: Any
76
+
77
+
78
class ServeCommand(BaseTransformersCLICommand):
    """Expose a transformers Pipeline over a small FastAPI REST application."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        """
        :param pipeline: already-instantiated Pipeline used to serve every request
        :param host: interface the HTTP server binds to
        :param port: TCP port to listen on
        :param workers: number of uvicorn workers
        :raises RuntimeError: when the optional serving dependencies are not installed
        """
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            # Fixed typo: the dependency is "uvicorn", not "unicorn".
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]".'
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info("Serving model over {}:{}".format(host, port))
            # One route per capability: model info, tokenize, detokenize, forward.
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
            )

    def run(self):
        """Start the uvicorn server (blocking call)."""
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        """Return the model configuration as a plain dict."""
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually returns corresponding tokens id:
        - **text_input**: String to tokenize
        - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            # Positional call matches tokenizer.decode(ids, skip_special_tokens, clean_up_tokenization_spaces).
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            # Fixed: ServeDeTokenizeResult declares no "model" field, so the stray model="" kwarg was dropped.
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the pipeline over **inputs** and return its raw output.
        """

        # Check we don't have empty string
        if len(inputs) == 0:
            # Fixed: ServeForwardResult declares no "attention" field, so the stray attention=[] kwarg was dropped.
            return ServeForwardResult(output=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
src/transformers/commands/train.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from argparse import ArgumentParser, Namespace
from logging import getLogger

from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand


# Training needs at least one deep-learning backend; fail fast at import time.
if not is_tf_available() and not is_torch_available():
    raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# TF training parameters
USE_XLA = False  # NOTE(review): defined but unused in this module's visible code.
USE_AMP = False  # NOTE(review): defined but unused in this module's visible code.
16
+
17
+
18
def train_command_factory(args: Namespace):
    """
    Factory function used to instantiate the training command from provided command line arguments.
    :return: TrainCommand
    """
    return TrainCommand(args)
24
+
25
+
26
class TrainCommand(BaseTransformersCLICommand):
    """CLI command that fine-tunes a pipeline on a (label, text) tab-separated csv dataset."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")

        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with "
            "tab separated labels and sentences.",
        )
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )

        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset " "to use as validation dataset.",
        )

        train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")

        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )
        train_parser.add_argument(
            "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
        )
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
        train_parser.set_defaults(func=train_command_factory)

    def __init__(self, args: Namespace):
        """Load the pipeline and the train/validation datasets from parsed CLI arguments."""
        self.logger = getLogger("transformers-cli/training")

        # Prefer TF when both backends are available; the module import already guaranteed at least one.
        self.framework = "tf" if is_tf_available() else "torch"

        os.makedirs(args.output, exist_ok=True)
        # NOTE(review): assert is stripped under `python -O`; makedirs above should already guarantee this.
        assert os.path.isdir(args.output)
        self.output = args.output

        self.column_label = args.column_label
        self.column_text = args.column_text
        self.column_id = args.column_id

        self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
        if args.task == "text_classification":
            # NOTE(review): assumes TextClassificationPipeline exposes from_pretrained — confirm against the
            # transformers version this targets; only text_classification is implemented here.
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError
        elif args.task == "question_answering":
            raise NotImplementedError

        self.logger.info("Loading dataset from {}".format(args.train_data))
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )
        # Validation dataset is optional; when absent, validation_split carves one out of the train set.
        self.valid_dataset = None
        if args.validation_data:
            self.logger.info("Loading validation dataset from {}".format(args.validation_data))
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
        self.valid_batch_size = args.valid_batch_size
        self.learning_rate = args.learning_rate
        self.adam_epsilon = args.adam_epsilon

    def run(self):
        """Dispatch to the backend-specific training loop."""
        if self.framework == "tf":
            return self.run_tf()
        return self.run_torch()

    def run_torch(self):
        # PyTorch training is not implemented for this command yet.
        raise NotImplementedError

    def run_tf(self):
        """Fit the pipeline on the training dataset with TensorFlow, then save it to the output folder."""
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save trained pipeline
        self.pipeline.save_pretrained(self.output)
src/transformers/commands/user.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from argparse import ArgumentParser
4
+ from getpass import getpass
5
+ from typing import List, Union
6
+
7
+ from requests.exceptions import HTTPError
8
+
9
+ from transformers.commands import BaseTransformersCLICommand
10
+ from transformers.hf_api import HfApi, HfFolder
11
+
12
+
13
+ UPLOAD_MAX_FILES = 15
14
+
15
+
16
class UserCommands(BaseTransformersCLICommand):
    """Registers the account-related subcommands: login/whoami/logout, s3 {ls,rm} and upload."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Authentication commands.
        parser.add_parser("login", help="Log in using the same credentials as on huggingface.co").set_defaults(
            func=lambda args: LoginCommand(args)
        )
        parser.add_parser(
            "whoami", help="Find out which huggingface.co account you are logged in as."
        ).set_defaults(func=lambda args: WhoamiCommand(args))
        parser.add_parser("logout", help="Log out").set_defaults(func=lambda args: LogoutCommand(args))

        # s3
        s3_root = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
        s3_commands = s3_root.add_subparsers(help="s3 related commands")
        s3_commands.add_parser("ls").set_defaults(func=lambda args: ListObjsCommand(args))
        remove_cmd = s3_commands.add_parser("rm")
        remove_cmd.add_argument("filename", type=str, help="individual object filename to delete from S3.")
        remove_cmd.set_defaults(func=lambda args: DeleteObjCommand(args))

        # upload
        upload_cmd = parser.add_parser("upload")
        upload_cmd.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_cmd.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_cmd.set_defaults(func=lambda args: UploadCommand(args))
40
+
41
+
42
class ANSI:
    """
    Tiny helper around ANSI escape codes.
    See en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = "\u001b[1m"
    _reset = "\u001b[0m"

    @classmethod
    def bold(cls, s):
        # Wrap *s* (stringified) in bold-on / attribute-reset escape sequences.
        return cls._bold + str(s) + cls._reset
53
+
54
+
55
class BaseUserCommand:
    """Common base for user subcommands: keeps the parsed args and a shared HfApi client."""

    def __init__(self, args):
        self.args = args
        self._api = HfApi()
59
+
60
+
61
class LoginCommand(BaseUserCommand):
    def run(self):
        """Prompt for huggingface.co credentials, exchange them for a token and cache it locally."""
        print(
            """
        _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
        _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
        _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
        _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
        _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

        """
        )
        username = input("Username: ")
        password = getpass()
        try:
            token = self._api.login(username, password)
        except HTTPError as e:
            # probably invalid credentials, display error message.
            print(e)
            exit(1)
        HfFolder.save_token(token)
        print("Login successful")
        # NOTE(review): echoing the token to stdout may leak credentials into terminal logs.
        print("Your token:", token, "\n")
        print("Your token has been saved to", HfFolder.path_token)
85
+
86
+
87
class WhoamiCommand(BaseUserCommand):
    def run(self):
        """Print the identity attached to the locally stored token, if any."""
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit()
        try:
            user = self._api.whoami(token)
            print(user)
        except HTTPError as e:
            print(e)
98
+
99
+
100
class LogoutCommand(BaseUserCommand):
    def run(self):
        """Delete the locally cached token, then revoke it server-side."""
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit()
        # Local copy is removed first; the token value is already held in `token` for the API call.
        HfFolder.delete_token()
        self._api.logout(token)
        print("Successfully logged out.")
109
+
110
+
111
class ListObjsCommand(BaseUserCommand):
    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
        """
        Render rows + headers as a simple left-aligned text table.
        Inspired by:
        stackoverflow.com/a/8356620/593036
        stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
        """
        # Width of each column = widest cell in that column, header included.
        widths = [max(len(str(cell)) for cell in column) for column in zip(*rows, headers)]
        row_template = ("{{:{}}} " * len(headers)).format(*widths)

        rendered = [row_template.format(*headers), row_template.format(*["-" * w for w in widths])]
        rendered.extend(row_template.format(*row) for row in rows)
        return "\n".join(rendered)

    def run(self):
        """List the objects stored on S3 for the logged-in user as a table."""
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            objs = self._api.list_objs(token)
        except HTTPError as e:
            print(e)
            exit(1)
        if len(objs) == 0:
            print("No shared file yet")
            exit()
        table_rows = [[remote.filename, remote.LastModified, remote.ETag, remote.Size] for remote in objs]
        print(self.tabulate(table_rows, headers=["Filename", "LastModified", "ETag", "Size"]))
142
+
143
+
144
class DeleteObjCommand(BaseUserCommand):
    def run(self):
        """Delete a single remote object (by filename) from the user's S3 storage."""
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            self._api.delete_obj(token, filename=self.args.filename)
        except HTTPError as e:
            print(e)
            exit(1)
        print("Done")
156
+
157
+
158
class UploadCommand(BaseUserCommand):
    """Upload a file or a whole folder to the logged-in user's S3 storage (with confirmation)."""

    def walk_dir(self, rel_path):
        """
        Recursively list all files in a folder.
        Returns (filepath, filename) pairs: filepath is absolute (joined on the current
        working directory), filename keeps the relative layout used as the S3 object name.
        """
        # NOTE(review): joining on os.getcwd() assumes rel_path is reachable from the
        # current working directory — confirm for folders passed by absolute path.
        entries: List[os.DirEntry] = list(os.scandir(rel_path))
        files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()]  # (filepath, filename)
        for f in entries:
            if f.is_dir():
                files += self.walk_dir(f.path)
        return files

    def run(self):
        """Resolve the set of files to upload, confirm with the user, then presign-and-upload each one."""
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        local_path = os.path.abspath(self.args.path)
        if os.path.isdir(local_path):
            if self.args.filename is not None:
                raise ValueError("Cannot specify a filename override when uploading a folder.")
            # NOTE(review): keeping only the basename assumes the cwd is the folder's parent;
            # an absolute path to a folder elsewhere would make os.scandir fail in walk_dir.
            rel_path = os.path.basename(local_path)
            files = self.walk_dir(rel_path)
        elif os.path.isfile(local_path):
            filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
            files = [(local_path, filename)]
        else:
            raise ValueError("Not a valid file or directory: {}".format(local_path))

        if sys.platform == "win32":
            # S3 object names always use forward slashes.
            files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]

        if len(files) > UPLOAD_MAX_FILES:
            print(
                "About to upload {} files to S3. This is probably wrong. Please filter files before uploading.".format(
                    ANSI.bold(len(files))
                )
            )
            exit(1)

        for filepath, filename in files:
            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

        # Empty answer defaults to "yes".
        choice = input("Proceed? [Y/n] ").lower()
        if not (choice == "" or choice == "y" or choice == "yes"):
            print("Abort")
            exit()
        print(ANSI.bold("Uploading... This might take a while if files are large"))
        for filepath, filename in files:
            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
            print("Your file now lives at:")
            print(access_url)