Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- .gitignore +17 -0
- LICENSE +201 -0
- README.md +380 -0
- SNP/SNP.py +85 -0
- SNP/example_mut_file.txt +6 -0
- SNP/examples/dev.tsv +6 -0
- SNP/mutate_seqs.py +118 -0
- examples/.Rhistory +0 -0
- examples/.run_pretrain.py.swp +0 -0
- examples/6mer_pretrain_emb/static_6mer_embeddings.npy +3 -0
- examples/6mer_pretrain_emb_20ways/static_6mer_embed_20ways.npy +3 -0
- examples/6mer_pretrain_emb_adaptive/static_adaptive_embed.npy +3 -0
- examples/compute_result.py +290 -0
- examples/data_process_template/.process_pretrain_data_multi.py.swp +0 -0
- examples/data_process_template/process_690.py +103 -0
- examples/data_process_template/process_csv.py +311 -0
- examples/data_process_template/process_finetune_data.py +713 -0
- examples/data_process_template/process_ner.py +132 -0
- examples/data_process_template/process_pretrain_data.py +148 -0
- examples/data_process_template/process_pretrain_data_multi.py +63 -0
- examples/data_process_template/process_scan_prom_data.py +76 -0
- examples/gen_cCRE_emb_final.py +113 -0
- examples/load_model_test.py +69 -0
- examples/requirements.txt +11 -0
- examples/run_finetune.py +1284 -0
- examples/run_pretrain.py +885 -0
- examples/run_pretrain.sh.save +36 -0
- examples/sample_data/ft/6/dev.tsv +0 -0
- examples/sample_data/ft/6/train.tsv +3 -0
- examples/sample_data/pre/6_3k.txt +0 -0
- examples/save_static_embeddings.py +65 -0
- examples/scripts/run_mut.sh +45 -0
- examples/scripts/uce.sh +26 -0
- examples/visualize.py +152 -0
- motif/find_motifs.py +112 -0
- motif/motif_utils.py +553 -0
- save2cache.py +224 -0
- setup.cfg +36 -0
- setup.py +127 -0
- src/transformers/__init__.py +436 -0
- src/transformers/activations.py +48 -0
- src/transformers/commands/__init__.py +13 -0
- src/transformers/commands/convert.py +144 -0
- src/transformers/commands/download.py +32 -0
- src/transformers/commands/env.py +58 -0
- src/transformers/commands/run.py +96 -0
- src/transformers/commands/serving.py +214 -0
- src/transformers/commands/train.py +144 -0
- src/transformers/commands/user.py +209 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
examples/sample_data/ft/6/train.tsv filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pyc
|
| 2 |
+
cache*
|
| 3 |
+
dna_cache*
|
| 4 |
+
examples/runs
|
| 5 |
+
examples/ft
|
| 6 |
+
examples/output*
|
| 7 |
+
examples/ft_new
|
| 8 |
+
examples/results
|
| 9 |
+
examples/data_old
|
| 10 |
+
examples/data
|
| 11 |
+
examples/result
|
| 12 |
+
examples/models
|
| 13 |
+
src/transformers/data/__pycache__
|
| 14 |
+
src/transformers/data/metrics/__pycache__
|
| 15 |
+
src/transformers/data/processors/__pycache__
|
| 16 |
+
src/transformers/__pycache__
|
| 17 |
+
src/transformers.egg-info
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DNABERT
|
| 2 |
+
This repository includes the implementation of 'DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome'. Please cite our paper if you use the models or codes. The repo is still actively under development, so please kindly report if there is any issue encountered.
|
| 3 |
+
|
| 4 |
+
In this package, we provide resources including: source codes of the DNABERT model, usage examples, pre-trained models, fine-tuned models and a visualization tool. This package is still under development, as more features will be included gradually. Training of DNABERT consists of general-purpose pre-training and task-specific fine-tuning. As a contribution of our project, we released the pre-trained models in this repository. We extended codes from [huggingface](https://github.com/huggingface/transformers) and adapted them to the DNA scenario.
|
| 5 |
+
|
| 6 |
+
## Update 2025/07/08
|
| 7 |
+
|
| 8 |
+
The original links to the pretrained DNABERT models (DNABERT-3, 4, 5, 6) have expired. Please go to HuggingFace to access and download the models:
|
| 9 |
+
|
| 10 |
+
DNABERT-3: https://huggingface.co/zhihan1996/DNA_bert_3
|
| 11 |
+
DNABERT-4: https://huggingface.co/zhihan1996/DNA_bert_4
|
| 12 |
+
DNABERT-5: https://huggingface.co/zhihan1996/DNA_bert_5
|
| 13 |
+
DNABERT-6: https://huggingface.co/zhihan1996/DNA_bert_6
|
| 14 |
+
|
| 15 |
+
## Update 2023/06/26
|
| 16 |
+
|
| 17 |
+
The second generation of DNABERT, named [DNABERT-2](https://arxiv.org/abs/2306.15006), is publicly available at https://github.com/Zhihan1996/DNABERT_2. DNABERT-2 is trained on multi-species genomes and is more efficient, powerful, and easy to use than its first generation. We also provide simpler usage of DNABERT in the new package. A comprehensive benchmark, Genome Understanding Evaluation (GUE), which contains $28$ datasets on $7$ tasks, is also published. Please check out DNABERT-2 if you are interested in our work. Thanks!
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## Citation
|
| 21 |
+
If you have used DNABERT in your research, please kindly cite the following publications:
|
| 22 |
+
|
| 23 |
+
```
|
| 24 |
+
@article{ji2021dnabert,
|
| 25 |
+
author = {Ji, Yanrong and Zhou, Zhihan and Liu, Han and Davuluri, Ramana V},
|
| 26 |
+
title = "{DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome}",
|
| 27 |
+
journal = {Bioinformatics},
|
| 28 |
+
volume = {37},
|
| 29 |
+
number = {15},
|
| 30 |
+
pages = {2112-2120},
|
| 31 |
+
year = {2021},
|
| 32 |
+
month = {02},
|
| 33 |
+
issn = {1367-4803},
|
| 34 |
+
doi = {10.1093/bioinformatics/btab083},
|
| 35 |
+
url = {https://doi.org/10.1093/bioinformatics/btab083},
|
| 36 |
+
eprint = {https://academic.oup.com/bioinformatics/article-pdf/37/15/2112/50578892/btab083.pdf},
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@misc{zhou2023dnabert2,
|
| 41 |
+
title={DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome},
|
| 42 |
+
author={Zhihan Zhou and Yanrong Ji and Weijian Li and Pratik Dutta and Ramana Davuluri and Han Liu},
|
| 43 |
+
year={2023},
|
| 44 |
+
eprint={2306.15006},
|
| 45 |
+
archivePrefix={arXiv},
|
| 46 |
+
primaryClass={q-bio.GN}
|
| 47 |
+
}
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
## 1. Environment setup
|
| 52 |
+
|
| 53 |
+
We recommend you to build a python virtual environment with [Anaconda](https://docs.anaconda.com/anaconda/install/linux/). Also, please make sure you have at least one NVIDIA GPU with Linux x86_64 Driver Version >= 410.48 (compatible with CUDA 10.0). We applied distributed training on 8 NVIDIA GeForce RTX 2080 Ti with 11 GB graphic memory, and the batch size corresponds to it. If you use GPU with other specifications and memory sizes, consider adjusting your batch size accordingly.
|
| 54 |
+
|
| 55 |
+
#### 1.1 Create and activate a new virtual environment
|
| 56 |
+
|
| 57 |
+
```
|
| 58 |
+
conda create -n dnabert python=3.6
|
| 59 |
+
conda activate dnabert
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
#### 1.2 Install the package and other requirements
|
| 65 |
+
|
| 66 |
+
(Required)
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
|
| 70 |
+
|
| 71 |
+
git clone https://github.com/jerryji1993/DNABERT
|
| 72 |
+
cd DNABERT
|
| 73 |
+
python3 -m pip install --editable .
|
| 74 |
+
cd examples
|
| 75 |
+
python3 -m pip install -r requirements.txt
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
(Optional, install apex for fp16 training)
|
| 81 |
+
|
| 82 |
+
change to a desired directory by `cd PATH_NAME`
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
git clone https://github.com/NVIDIA/apex
|
| 86 |
+
cd apex
|
| 87 |
+
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
## 2. Pre-train (Skip this section if you fine-tune on pre-trained models)
|
| 95 |
+
|
| 96 |
+
#### 2.1 Data processing
|
| 97 |
+
|
| 98 |
+
Please see the template data at `/example/sample_data/pre`. If you are trying to pre-train DNABERT with your own data, please process your data into the same format. Note that the sequences are in kmer format, so you will need to convert your sequences into that. We also provide a custom function `seq2kmer` in `motif/motif_utils.py` for this conversion.
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
In the following example, we use DNABERT with kmer=6 as example.
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
#### 2.2 Model Training
|
| 107 |
+
|
| 108 |
+
```
|
| 109 |
+
cd examples
|
| 110 |
+
|
| 111 |
+
export KMER=6
|
| 112 |
+
export TRAIN_FILE=sample_data/pre/6_3k.txt
|
| 113 |
+
export TEST_FILE=sample_data/pre/6_3k.txt
|
| 114 |
+
export SOURCE=PATH_TO_DNABERT_REPO
|
| 115 |
+
export OUTPUT_PATH=output$KMER
|
| 116 |
+
|
| 117 |
+
python run_pretrain.py \
|
| 118 |
+
--output_dir $OUTPUT_PATH \
|
| 119 |
+
--model_type=dna \
|
| 120 |
+
--tokenizer_name=dna$KMER \
|
| 121 |
+
--config_name=$SOURCE/src/transformers/dnabert-config/bert-config-$KMER/config.json \
|
| 122 |
+
--do_train \
|
| 123 |
+
--train_data_file=$TRAIN_FILE \
|
| 124 |
+
--do_eval \
|
| 125 |
+
--eval_data_file=$TEST_FILE \
|
| 126 |
+
--mlm \
|
| 127 |
+
--gradient_accumulation_steps 25 \
|
| 128 |
+
--per_gpu_train_batch_size 10 \
|
| 129 |
+
--per_gpu_eval_batch_size 6 \
|
| 130 |
+
--save_steps 500 \
|
| 131 |
+
--save_total_limit 20 \
|
| 132 |
+
--max_steps 200000 \
|
| 133 |
+
--evaluate_during_training \
|
| 134 |
+
--logging_steps 500 \
|
| 135 |
+
--line_by_line \
|
| 136 |
+
--learning_rate 4e-4 \
|
| 137 |
+
--block_size 512 \
|
| 138 |
+
--adam_epsilon 1e-6 \
|
| 139 |
+
--weight_decay 0.01 \
|
| 140 |
+
--beta1 0.9 \
|
| 141 |
+
--beta2 0.98 \
|
| 142 |
+
--mlm_probability 0.025 \
|
| 143 |
+
--warmup_steps 10000 \
|
| 144 |
+
--overwrite_output_dir \
|
| 145 |
+
--n_process 24
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Add the --fp16 tag if you want to perform mixed precision. (You have to install 'apex' from source first.)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
## 3. Fine-tune (Skip this section if you use fine-tuned model)
|
| 155 |
+
|
| 156 |
+
#### 3.1 Data processing
|
| 157 |
+
|
| 158 |
+
Please see the template data at `/example/sample_data/ft/`. If you are trying to fine-tune DNABERT with your own data, please process your data into the same format. Note that the sequences are in kmer format, so you will need to convert your sequences into that. We also provide a custom function `seq2kmer` in `motif/motif_utils.py` for this conversion.
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
#### 3.2 Download pre-trained DNABERT
|
| 163 |
+
|
| 164 |
+
[DNABERT3](https://drive.google.com/file/d/1nVBaIoiJpnwQxiz4dSq6Sv9kBKfXhZuM/view?usp=sharing)
|
| 165 |
+
|
| 166 |
+
[DNABERT4](https://drive.google.com/file/d/1V7CChcC6KgdJ7Gwdyn73OS6dZR_J-Lrs/view?usp=sharing)
|
| 167 |
+
|
| 168 |
+
[DNABERT5](https://drive.google.com/file/d/1KMqgXYCzrrYD1qxdyNWnmUYPtrhQqRBM/view?usp=sharing)
|
| 169 |
+
|
| 170 |
+
[DNABERT6](https://drive.google.com/file/d/1BJjqb5Dl2lNMg2warsFQ0-Xvn1xxfFXC/view?usp=sharing)
|
| 171 |
+
|
| 172 |
+
Download the pre-trained model in to a directory. (If you would like to replicate the following examples, please download DNABERT 6). Then unzip the package by running:
|
| 173 |
+
|
| 174 |
+
```
|
| 175 |
+
unzip 6-new-12w-0.zip
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
We also provide a model with `KMER=6` that is fine-tuned on the sample dataset for prediction/visualization/motif analysis. If you use the fine-tuned model instead of fine-tuning a model by yourself, please download the fine-tuned model and put it under `examples/ft/6`.
|
| 179 |
+
|
| 180 |
+
[Fine-tuned Model](https://drive.google.com/drive/folders/15wFcukTv3ecPw9_25dcOv-bZmj-8d_-6?usp=sharing)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
#### 3.3 Fine-tune with pre-trained model
|
| 184 |
+
|
| 185 |
+
In the following example, we use DNABERT with kmer=6. We use `prom-core`, a 2-class classification task, as the example.
|
| 186 |
+
|
| 187 |
+
```
|
| 188 |
+
cd examples
|
| 189 |
+
|
| 190 |
+
export KMER=6
|
| 191 |
+
export MODEL_PATH=PATH_TO_THE_PRETRAINED_MODEL
|
| 192 |
+
export DATA_PATH=sample_data/ft/$KMER
|
| 193 |
+
export OUTPUT_PATH=./ft/$KMER
|
| 194 |
+
|
| 195 |
+
python run_finetune.py \
|
| 196 |
+
--model_type dna \
|
| 197 |
+
--tokenizer_name=dna$KMER \
|
| 198 |
+
--model_name_or_path $MODEL_PATH \
|
| 199 |
+
--task_name dnaprom \
|
| 200 |
+
--do_train \
|
| 201 |
+
--do_eval \
|
| 202 |
+
--data_dir $DATA_PATH \
|
| 203 |
+
--max_seq_length 100 \
|
| 204 |
+
--per_gpu_eval_batch_size=32 \
|
| 205 |
+
--per_gpu_train_batch_size=32 \
|
| 206 |
+
--learning_rate 2e-4 \
|
| 207 |
+
--num_train_epochs 5.0 \
|
| 208 |
+
--output_dir $OUTPUT_PATH \
|
| 209 |
+
--evaluate_during_training \
|
| 210 |
+
--logging_steps 100 \
|
| 211 |
+
--save_steps 4000 \
|
| 212 |
+
--warmup_percent 0.1 \
|
| 213 |
+
--hidden_dropout_prob 0.1 \
|
| 214 |
+
--overwrite_output \
|
| 215 |
+
--weight_decay 0.01 \
|
| 216 |
+
--n_process 8
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
Add the --fp16 tag if you want to perform mixed precision. (You have to install 'apex' from source first.)
|
| 220 |
+
|
| 221 |
+
We also provide a model with `KMER=6` that is fine-tuned on the sample dataset for prediction/visualization/motif analysis. If you use the fine-tuned model instead of fine-tuning a model by yourself, please download the fine-tuned model and put it under `examples/ft/6`.
|
| 222 |
+
|
| 223 |
+
[Fine-tuned Model](https://drive.google.com/drive/folders/15wFcukTv3ecPw9_25dcOv-bZmj-8d_-6?usp=sharing)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
## 4. Prediction
|
| 228 |
+
|
| 229 |
+
After the model is fine-tuned, we can get predictions by running
|
| 230 |
+
|
| 231 |
+
```$
|
| 232 |
+
export KMER=6
|
| 233 |
+
export MODEL_PATH=./ft/$KMER
|
| 234 |
+
export DATA_PATH=sample_data/ft/$KMER
|
| 235 |
+
export PREDICTION_PATH=./result/$KMER
|
| 236 |
+
|
| 237 |
+
python run_finetune.py \
|
| 238 |
+
--model_type dna \
|
| 239 |
+
--tokenizer_name=dna$KMER \
|
| 240 |
+
--model_name_or_path $MODEL_PATH \
|
| 241 |
+
--task_name dnaprom \
|
| 242 |
+
--do_predict \
|
| 243 |
+
--data_dir $DATA_PATH \
|
| 244 |
+
--max_seq_length 75 \
|
| 245 |
+
--per_gpu_pred_batch_size=128 \
|
| 246 |
+
--output_dir $MODEL_PATH \
|
| 247 |
+
--predict_dir $PREDICTION_PATH \
|
| 248 |
+
--n_process 48
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
With the above command, the fine-tuned DNABERT model will be loaded from `MODEL_PATH` , and makes prediction on the `dev.tsv` file that saved in `DATA_PATH` and save the prediction result at `PREDICTION_PATH`.
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
Add the --fp16 tag if you want to perform mixed precision. (You have to install 'apex' from source first.)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
## 5. Visualization
|
| 258 |
+
|
| 259 |
+
Visualization of DNABERT consists of 2 steps: calculate attention scores, then plot.
|
| 260 |
+
|
| 261 |
+
#### 5.1 Calculate attention scores
|
| 262 |
+
|
| 263 |
+
calculate with only one model (For example, DNABERT6)
|
| 264 |
+
|
| 265 |
+
```
|
| 266 |
+
export KMER=6
|
| 267 |
+
export MODEL_PATH=./ft/$KMER
|
| 268 |
+
export DATA_PATH=sample_data/ft/$KMER
|
| 269 |
+
export PREDICTION_PATH=./result/$KMER
|
| 270 |
+
|
| 271 |
+
python run_finetune.py \
|
| 272 |
+
--model_type dna \
|
| 273 |
+
--tokenizer_name=dna$KMER \
|
| 274 |
+
--model_name_or_path $MODEL_PATH \
|
| 275 |
+
--task_name dnaprom \
|
| 276 |
+
--do_visualize \
|
| 277 |
+
--visualize_data_dir $DATA_PATH \
|
| 278 |
+
--visualize_models $KMER \
|
| 279 |
+
--data_dir $DATA_PATH \
|
| 280 |
+
--max_seq_length 81 \
|
| 281 |
+
--per_gpu_pred_batch_size=16 \
|
| 282 |
+
--output_dir $MODEL_PATH \
|
| 283 |
+
--predict_dir $PREDICTION_PATH \
|
| 284 |
+
--n_process 96
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
With the above command, the fine-tuned DNABERT model will be loaded from `MODEL_PATH` , and calculates attention scores on the `dev.tsv` file that saved in `DATA_PATH` and save the result at `PREDICTION_PATH`.
|
| 288 |
+
|
| 289 |
+
Add the --fp16 tag if you want to perform mixed precision. (You have to install 'apex' from source first.)
|
| 290 |
+
|
| 291 |
+
#### 5.2 Plotting tool
|
| 292 |
+
|
| 293 |
+
## 6. Motif analysis
|
| 294 |
+
|
| 295 |
+
Once the attention scores are generated, we can proceed further to perform motif analysis using `motif/find_motifs.py`:
|
| 296 |
+
|
| 297 |
+
```
|
| 298 |
+
cd ../motif
|
| 299 |
+
|
| 300 |
+
export KMER=6
|
| 301 |
+
export DATA_PATH=../examples/sample_data/ft/$KMER
|
| 302 |
+
export PREDICTION_PATH=../examples/result/$KMER
|
| 303 |
+
export MOTIF_PATH=./result/$KMER
|
| 304 |
+
|
| 305 |
+
python find_motifs.py \
|
| 306 |
+
--data_dir $DATA_PATH \
|
| 307 |
+
--predict_dir $PREDICTION_PATH \
|
| 308 |
+
--window_size 24 \
|
| 309 |
+
--min_len 5 \
|
| 310 |
+
--pval_cutoff 0.005 \
|
| 311 |
+
--min_n_motif 3 \
|
| 312 |
+
--align_all_ties \
|
| 313 |
+
--save_file_dir $MOTIF_PATH \
|
| 314 |
+
--verbose
|
| 315 |
+
```
|
| 316 |
+
|
| 317 |
+
The script will generate a .txt file and a weblogo .png file for each motif under `MOTIF_PATH`.
|
| 318 |
+
|
| 319 |
+
## 7. Genomic variants analysis
|
| 320 |
+
|
| 321 |
+
To perform genomic variants analysis (e.g. SNPs), we need to first ensure the predictions for the sequences were generated. Then, create a file (template in `SNP/example_mut_file.txt`) specifying for which sequences in `dev.tsv` and start and end indices where we need to perform the mutation. The first column indicates the index of sequence in `dev.tsv` to be mutated. Second and third columns are the start and end indices while the fourth column is the target of mutation (can be substitution, insertion, deletion, etc.)
|
| 322 |
+
|
| 323 |
+
Once such a file is created, we can perform mutation on the sequences:
|
| 324 |
+
|
| 325 |
+
```
|
| 326 |
+
cd ../SNP
|
| 327 |
+
python mutate_seqs.py ./../examples/sample_data/ft/6/dev.tsv ./examples/ --mut_file ./example_mut_file.txt --k 6
|
| 328 |
+
```
|
| 329 |
+
Alternatively, we can choose to leave the `--mut_file` argument blank, where the program would try to perform substitution of all bases to the four possible nucleotides ('A', 'T', 'C', or 'G') for all sequences. This would be useful for plotting a mutation heatmap as included in the paper. **Note that this would be slow if the `dev.tsv` contains a lot of sequences or the input sequences are very long, as the command would try to perform mutation on all possible locations of them**.
|
| 330 |
+
|
| 331 |
+
```
|
| 332 |
+
cd ../SNP
|
| 333 |
+
python mutate_seqs.py ./../examples/sample_data/ft/6/dev.tsv ./examples/ --k 6
|
| 334 |
+
```
|
| 335 |
+
|
| 336 |
+
After that, we can again predict on the generated sequences. **Note: if you have insertion/deletions in your `mut_file.txt`, consider changing the `max_seq_length` we use when making predictions.**
|
| 337 |
+
|
| 338 |
+
```
|
| 339 |
+
export KMER=6
|
| 340 |
+
export MODEL_PATH=../examples/ft/$KMER
|
| 341 |
+
export DATA_PATH=examples
|
| 342 |
+
export PREDICTION_PATH=examples
|
| 343 |
+
|
| 344 |
+
python ../examples/run_finetune.py \
|
| 345 |
+
--model_type dna \
|
| 346 |
+
--tokenizer_name=dna$KMER \
|
| 347 |
+
--model_name_or_path $MODEL_PATH \
|
| 348 |
+
--task_name dnaprom \
|
| 349 |
+
--do_predict \
|
| 350 |
+
--data_dir $DATA_PATH \
|
| 351 |
+
--max_seq_length 75 \
|
| 352 |
+
--per_gpu_pred_batch_size=128 \
|
| 353 |
+
--output_dir $MODEL_PATH \
|
| 354 |
+
--predict_dir $PREDICTION_PATH \
|
| 355 |
+
--n_process 48
|
| 356 |
+
```
|
| 357 |
+
|
| 358 |
+
This will again create `pred_results.npy` file under the `$PREDICTION_PATH`. Once we have all the above, we can compute the effect of these mutations by:
|
| 359 |
+
|
| 360 |
+
```
|
| 361 |
+
python SNP.py \
|
| 362 |
+
--orig_seq_file ../examples/sample_data/ft/6/dev.tsv \
|
| 363 |
+
--orig_pred_file ../examples/result/6/pred_results.npy \
|
| 364 |
+
--mut_seq_file examples/dev.tsv \
|
| 365 |
+
--mut_pred_file examples/pred_results.npy \
|
| 366 |
+
--save_file_dir examples
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
This would save a `mutations.tsv` file under `save_file_dir`, that contains index of original sequence (in original `dev.tsv`), original sequence and predictions, mutated sequence and predictions, as well as the difference score and log odds ratio of the change in every case.
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
## Q&A
|
| 373 |
+
|
| 374 |
+
#### 1. I cannot start training the model/I have installation issues for the dependencies.
|
| 375 |
+
|
| 376 |
+
Please kindly make sure that you satisfied all system requirements for DNABERT, and that you have a conda environment properly set up. We have recently successfully tested our pipeline on Amazon EC2 Deep Learning AMI (Ubuntu 18.04). As an option, you could compare your system/environment setup with this AMI.
|
| 377 |
+
|
| 378 |
+
#### 2. Can DNABERT run on sequences longer than 512?
|
| 379 |
+
|
| 380 |
+
#### 3. Can DNABERT be extended to multi-class classification?
|
SNP/SNP.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#### ::: DNABERT-viz SNP analysis ::: ####
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('../motif')
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
import motif_utils as utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def main():
    """Compare model predictions on original vs. mutated sequences.

    Joins the original and mutated dev sets on the shared sequence index,
    attaches the corresponding prediction arrays, computes a difference
    score and a log odds ratio for each mutation, and writes the combined
    table to ``mutations.tsv`` under ``--save_file_dir``.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--orig_seq_file",
        default='../examples/sample_data/ft/prom-core/6/dev.tsv',
        type=str,
        required=True,
        help="Path to original input sequence+label .tsv file.",
    )
    arg_parser.add_argument(
        "--orig_pred_file",
        required=True,
        type=str,
        default='../examples/result/prom-core/6/pred.npy',
        help="Path to predictions pred.npy of original sequences.",
    )
    arg_parser.add_argument(
        "--mut_seq_file",
        default='examples/dev.tsv',
        type=str,
        required=True,
        help="Path to mutated sequence+index .tsv file.",
    )
    arg_parser.add_argument(
        "--mut_pred_file",
        required=True,
        type=str,
        default='examples/pred.npy',
        help="Path to predictions pred_results.npy of mutated sequences.",
    )
    arg_parser.add_argument(
        "--save_file_dir",
        default='.',
        type=str,
        help="Path to save outputs",
    )

    # TODO: add the conditions
    args = arg_parser.parse_args()

    # Original sequences: recover the raw sequence from its k-mer form and
    # attach the model predictions in file order.
    originals = pd.read_csv(args.orig_seq_file, sep='\t', header=0)
    originals.columns = ['sequence', 'label']
    originals['orig_seq'] = originals['sequence'].apply(utils.kmer2seq)
    originals['idx'] = originals.index
    originals['orig_pred'] = np.load(args.orig_pred_file)

    # Mutated sequences carry the index of the sequence they came from.
    mutated = pd.read_csv(args.mut_seq_file, sep='\t', header=0)
    mutated.columns = ['sequence', 'label', 'idx']  # ignore label
    mutated['mut_seq'] = mutated['sequence'].apply(utils.kmer2seq)
    mutated['mut_pred'] = np.load(args.mut_pred_file)

    # Pair each mutated sequence with its original via the shared index.
    merged = pd.merge(
        originals[['idx', 'orig_seq', 'orig_pred']],
        mutated[['idx', 'mut_seq', 'mut_pred']],
        on='idx',
    )
    # Difference score weighted by the larger of the two predictions.
    merged['diff'] = (merged['mut_pred'] - merged['orig_pred']) * (merged[['orig_pred', 'mut_pred']].max(axis=1))
    # Log odds ratio of original vs. mutated prediction.
    merged['logOR'] = np.log2(merged['orig_pred'] / (1 - merged['orig_pred'])) - np.log2(merged['mut_pred'] / (1 - merged['mut_pred']))
    merged.to_csv(os.path.join(args.save_file_dir, 'mutations.tsv'), sep='\t')
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
SNP/example_mut_file.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
0 30 31 G
|
| 2 |
+
23 52 53 T
|
| 3 |
+
104 14 15 C
|
| 4 |
+
125 22 23 A
|
| 5 |
+
240 8 8 A
|
| 6 |
+
325 10 11
|
SNP/examples/dev.tsv
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TTTTTA TTTTAA TTTAAA TTAAAA TAAAAG AAAAGT AAAGTA AAGTAA AGTAAA GTAAAC TAAACA AAACAC AACACT ACACTG CACTGT ACTGTT CTGTTT TGTTTT GTTTTC TTTTCA TTTCAT TTCATT TCATTA CATTAG ATTAGG TTAGGG TAGGGC AGGGCC GGGCCA GGCCAA GCCAAG CCAAGC CAAGCT AAGCTA AGCTAA GCTAAT CTAATC TAATCC AATCCT ATCCTT TCCTTA CCTTAT CTTATT TTATTG TATTGA ATTGAG TTGAGA TGAGAA GAGAAT AGAATT GAATTT AATTTC ATTTCT TTTCTA TTCTAA TCTAAA CTAAAG TAAAGG AAAGGG AAGGGA AGGGAC GGGACA GGACAT GACATT ACATTA 0
|
| 2 |
+
CGCATT GCATTA CATTAA ATTAAT TTAATA TAATAG AATAGT ATAGTG TAGTGG AGTGGA GTGGAC TGGACT GGACTA GACTAG ACTAGG CTAGGG TAGGGG AGGGGC GGGGCA GGGCAG GGCAGG GCAGGG CAGGGC AGGGCT GGGCTG GGCTGG GCTGGA CTGGAT TGGATT GGATTT GATTTT ATTTTC TTTTCG TTTCGG TTCGGA TCGGAG CGGAGG GGAGGC GAGGCA AGGCAG GGCAGT GCAGTG CAGTGT AGTGTG GTGTGC TGTGCA GTGCAG TGCAGT GCAGTT CAGTTC AGTTCC GTTCCC TTCCCA TCCCAA CCCAAT CCAATA CAATAA AATAAC ATAACT TAACTA AACTAG ACTAGT CTAGTT TAGTTC AGTTCC 23
|
| 3 |
+
TTCATA TCATAA CATAAA ATAAAT TAAATT AAATTA AATTAC ATTACC TTACCC TACCCC ACCCCG CCCCGT CCCGTT CCGTTT CGTTTC GTTTCT TTTCTC TTCTCA TCTCAT CTCATA TCATAG CATAGT ATAGTT TAGTTC AGTTCT GTTCTT TTCTTT TCTTTA CTTTAT TTTATA TTATAG TATAGC ATAGCA TAGCAG AGCAGT GCAGTG CAGTGT AGTGTG GTGTGA TGTGAA GTGAAA TGAAAA GAAAAC AAAACA AAACAG AACAGA ACAGAC CAGACT AGACTA GACTAA ACTAAT CTAATG TAATGG AATGGA ATGGAC TGGACC GGACCC GACCCT ACCCTT CCCTTC CCTTCT CTTCTG TTCTGG TCTGGT CTGGTT 104
|
| 4 |
+
GAGATA AGATAA GATAAA ATAAAG TAAAGG AAAGGA AAGGAA AGGAAG GGAAGG GAAGGG AAGGGA AGGGAA GGGAAT GGAATC GAATCA AATCAG ATCAGT TCAGTA CAGTAC AGTACC GTACCA TACCAT ACCATC CCATCC CATCCA ATCCAG TCCAGA CCAGAA CAGAAG AGAAGC GAAGCA AAGCAA AGCAAT GCAATG CAATGA AATGAG ATGAGA TGAGAT GAGATG AGATGG GATGGA ATGGAG TGGAGG GGAGGG GAGGGC AGGGCA GGGCAG GGCAGC GCAGCA CAGCAG AGCAGG GCAGGG CAGGGA AGGGAG GGGAGG GGAGGA GAGGAG AGGAGA GGAGAG GAGAGA AGAGAA GAGAAA AGAAAG GAAAGA AAAGAC 125
|
| 5 |
+
GGTACA GTACAA TACAAA ACAAAA CAAAAG AAAAGA AAAGAC AAGACG AGACGA GACGAA ACGAAC CGAACA GAACAA AACAAC ACAACG CAACGC AACGCC ACGCCA CGCCAT GCCATC CCATCC CATCCC ATCCCC TCCCCG CCCCGT CCCGTC CCGTCG CGTCGT GTCGTC TCGTCG CGTCGA GTCGAA TCGAAT CGAATG GAATGG AATGGC ATGGCA TGGCAG GGCAGA GCAGAC CAGACA AGACAA GACAAG ACAAGT CAAGTA AAGTAA AGTAAC GTAACC TAACCA AACCAG ACCAGT CCAGTC CAGTCT AGTCTT GTCTTT TCTTTG CTTTGT TTTGTA TTGTAA TGTAAC GTAACG TAACGT AACGTA ACGTAG CGTAGT GTAGTG 240
|
| 6 |
+
GGAACT GAACTT AACTTA ACTTAA CTTAAA TTAAAn TAAAna AAAnan AAnanG AnanGG nanGGC anGGCC nGGCCG GGCCGG GCCGGC CCGGCT CGGCTG GGCTGT GCTGTT CTGTTT TGTTTC GTTTCG TTTCGG TTCGGC TCGGCG CGGCGG GGCGGC GCGGCC CGGCCG GGCCGC GCCGCG CCGCGG CGCGGG GCGGGA CGGGAT GGGATG GGATGC GATGCC ATGCCC TGCCCC GCCCCT CCCCTG CCCTGC CCTGCG CTGCGC TGCGCT GCGCTG CGCTGA GCTGAC CTGACC TGACCG GACCGC ACCGCC CCGCCA CGCCAG GCCAGG CCAGGG CAGGGG AGGGGC GGGGCA GGGCAG GGCAGG GCAGGT CAGGTG AGGTGC GGTGCC GTGCCC 325
|
SNP/mutate_seqs.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#### ::: mutate seqs ::: ####
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('../motif')
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import argparse
|
| 9 |
+
import motif_utils as utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def mutate(seq, start, end, target=None):
    """
    Mutate the input sequence on the half-open interval [start, end).

    If ``target`` is given, the bases in ``seq[start:end]`` are replaced by
    ``target`` (a substitution when the lengths match, otherwise an
    insertion or deletion) and the mutated string is returned.  If
    ``target`` is None, the interval is replaced by each of 'A', 'T', 'G',
    'C' in turn and a numpy array of shape (4,) holding the four mutated
    sequences is returned.  (The original docstring claimed shape (4,1),
    which was inaccurate.)

    Arguments:
    seq -- str, original sequence.
    start -- int, starting index where nucleotide needs to be changed. Counting starts at zero.
    end -- int, ending index (exclusive) where nucleotide needs to be changed. Counting starts at zero.

    Keyword arguments:
    target -- str, the target nucleotide(s) to be changed to (default: None).

    Returns:
    mutated_seq -- str if ``target`` is given, otherwise numpy array of
        shape (4,) with the four substituted sequences.
    """
    assert end >= start and start >= 0 and end <= len(seq), "Wrong start and end index input."

    if target is not None:
        mutated_seq = seq[:start] + str(target) + seq[end:]
    else:
        # Enumerate substitution by every one of the four nucleotides.
        mutated_seq = np.asarray([seq[:start] + n + seq[end:] for n in ['A', 'T', 'G', 'C']])
    return mutated_seq
|
| 42 |
+
|
| 43 |
+
def main():
    """Generate mutated variants of input sequences and write them as dev.tsv.

    With ``--mut_file``, each listed sequence is mutated exactly as
    specified (sequence index, start, end, target allele).  Without it,
    every position of every sequence is substituted by each of the four
    nucleotides.  Each output row keeps the index of its source sequence so
    predictions can later be joined back to the originals.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "seq_file",
        type=str,
        help="Path to input sequence+label .tsv file.",
    )
    cli.add_argument(
        "save_file_dir",
        type=str,
        help="Path to save the mutated seqs",
    )
    cli.add_argument(
        "--mut_file",
        default=None,
        type=str,
        help="Path to the file defining how each input seq should be mutated",
    )
    cli.add_argument(
        "--k",
        default=3,
        type=int,
        help="length of kmer for conversion of mutated seqs"
    )

    # TODO: add the conditions
    args = cli.parse_args()

    os.makedirs(args.save_file_dir, exist_ok=True)

    records = {'index': [], 'seq': []}

    dev = pd.read_csv(args.seq_file, sep='\t', header=0)
    dev.columns = ['sequence', 'label']
    dev['seq'] = dev['sequence'].apply(utils.kmer2seq)

    if args.mut_file is not None:
        # Apply exactly the mutations listed in the mutation file.
        spec = pd.read_csv(args.mut_file, sep='\t', header=None).fillna('')
        spec.columns = ['idx', 'start', 'end', 'allele']
        for col in ('idx', 'start', 'end'):
            spec[col] = spec[col].astype(int)
        selected = dev.iloc[spec['idx'].tolist(), :].reset_index()
        for pos, row in selected.iterrows():
            entry = spec.iloc[pos]
            variant = mutate(row['seq'], entry['start'], entry['end'], target=entry['allele'])
            records['index'].append(entry['idx'])
            records['seq'].append(utils.seq2kmer(variant, args.k))
    else:
        # Exhaustively substitute every position with each of A/T/G/C.
        for pos, row in dev.iterrows():
            original = row['seq']
            for j in range(len(original)):
                variants = mutate(original, j, j + 1)
                records['index'].extend([pos] * 4)
                records['seq'].extend(utils.seq2kmer(v, args.k) for v in variants)

    out = pd.DataFrame.from_dict(records)[['seq', 'index']]
    out.columns = ['sequence', 'index']
    out['label'] = 0
    # Force a single positive label so downstream tools see both classes.
    out.iloc[0, out.columns.get_loc('label')] = 1
    out = out[['sequence', 'label', 'index']]

    out.to_csv(os.path.join(args.save_file_dir, 'dev.tsv'), sep='\t', header=True, index=False)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
main()
|
examples/.Rhistory
ADDED
|
File without changes
|
examples/.run_pretrain.py.swp
ADDED
|
Binary file (1.02 kB). View file
|
|
|
examples/6mer_pretrain_emb/static_6mer_embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5422f25436f65a3cb50f5e3881ab1a4c0e3d417eb8fb11f485fc1f9b0ef0b04d
|
| 3 |
+
size 12598400
|
examples/6mer_pretrain_emb_20ways/static_6mer_embed_20ways.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3e621f2367d58715c3defef6e0a504feed12e96a308da56f19383e68534e6b03
|
| 3 |
+
size 12598400
|
examples/6mer_pretrain_emb_adaptive/static_adaptive_embed.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:41de47985ee1cd6d29a98951beece1d79d7c48e6295e7701e7bfb46f06079705
|
| 3 |
+
size 12598400
|
examples/compute_result.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
import csv
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from sklearn.metrics import matthews_corrcoef, confusion_matrix, f1_score
|
| 6 |
+
|
| 7 |
+
def generate_pred(predict_results, i, slide, metric="max"):
    """
    Aggregate the per-window prediction scores of example ``i`` into one score.

    Arguments:
    predict_results -- 1-D array of window prediction scores, laid out so
        that example i owns predict_results[i*slide:(i+1)*slide].
    i -- int, example index.
    slide -- int, number of window predictions per example.

    Keyword arguments:
    metric -- "max", "mean" or "second-max" aggregation (default: "max").

    Returns:
    pred -- float, aggregated prediction score for example i.
    """
    # Bug fix: the original hard-coded a window count of 3 here, silently
    # ignoring the `slide` parameter.
    results = predict_results[i * slide:(i + 1) * slide]

    if metric == "max":
        pred = max(results)
    elif metric == "mean":
        pred = np.mean(results)
    elif metric == "second-max":
        pred = np.sort(results)[-2]
    else:
        # Fail loudly: the original fell through and raised NameError on
        # the undefined `pred` below.
        raise ValueError("Unknown metric: %s" % metric)

    return pred
|
| 21 |
+
|
| 22 |
+
def Compute_scan(args):
    """Evaluate scan-task predictions stored on disk and print the metrics.

    Loads per-window scores from ``args.pred_path`` and binary labels from
    ``args.label_path``, aggregates each example's windows with
    ``generate_pred`` (using ``args.slide`` windows and ``args.metric``),
    thresholds at ``args.bound``, then prints f1 / mcc / accuracy and the
    confusion-matrix counts to stdout.
    """
    predict_results = np.load(args.pred_path)
    labels = list(np.load(args.label_path).astype(int))

    # Threshold the aggregated score of each example at args.bound.
    results = [
        1 if generate_pred(predict_results, i, args.slide, args.metric) >= args.bound else 0
        for i in range(len(labels))
    ]

    # (The original also built unused sets of results/labels; removed.)
    f1 = f1_score(y_true=labels, y_pred=results)
    mcc = matthews_corrcoef(labels, results)
    tn, fp, fn, tp = confusion_matrix(labels, results).ravel()

    # Count correctly classified examples for the accuracy printout.
    count = sum(1 for r, l in zip(results, labels) if r == l)

    print("number of examples: " + str(len(labels)))
    print("number of positive examples: " + str(sum(labels)))
    print("number of negative examples: " + str(len(labels)-sum(labels)))
    print("f1: ", str(f1))
    print("mcc: " + str(mcc))
    print("accuracy: " + str(float(count)/len(results)))
    print("tn:" + str(tn))
    print("fp:" + str(fp))
    print("fn:" + str(fn))
    print("tp:" + str(tp))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def Compute_mouse(args):
    """Summarize mouse-task results.

    Groups consecutive result lines that share a task name, picks each
    task's best run according to ``args.index`` ("acc" or "auc"), prints
    the metrics averaged over tasks, and lists the worst-performing tasks.
    """
    # NOTE(review): the file handle is never closed — consider a `with` block.
    result_file = open(args.pred_path, "r")
    results = result_file.readlines()
    print(len(results))

    # Group consecutive lines sharing the same task name (column 0).
    # Columns 1..7 are parsed as floats: presumably acc, auc, aupr, f1,
    # mcc, precision, recall — confirm against the producer of this file.
    all_preds = []
    current_preds = []
    for result in results:
        scores = result.split()
        scores = [scores[0], float(scores[1]), float(scores[2]), float(scores[3]), float(scores[4]), float(scores[5]), float(scores[6]), float(scores[7])]
        if current_preds == [] or scores[0] == current_preds[0][0]:
            current_preds.append(scores)
        else:
            all_preds.append(current_preds)
            current_preds = []
            current_preds.append(scores)
    all_preds.append(current_preds)

    print("Number of task: %d" % len(all_preds))

    # Sort keys: accuracy is column 1, AUC is column 2.
    def get_acc(val):
        return val[1]

    def get_auc(val):
        return val[2]

    tasks = []
    acc = []
    auc = []
    aupr = []
    f1 = []
    mcc = []
    precision = []
    recall = []

    for pred in all_preds:
        if len(pred) < 10 :
            print("Short %s : %d" % (pred[0][0], len(pred)))

        if args.index == "acc":
            pred.sort(key=get_acc)
        elif args.index == "auc":
            pred.sort(key=get_auc)
        else:
            raise ValueError()

        # After sorting, pred[-1] holds the best primary metric.  BEST picks
        # the last run tied with pred[-1] on accuracy but with strictly
        # higher AUC, if any; otherwise BEST stays -1 (i.e. pred[-1] itself).
        BEST = -1
        for i in range(len(pred)):
            if pred[i][1] == pred[-1][1] and pred[i][2] > pred[-1][2]:
                BEST = deepcopy(i)
        tasks.append(pred[0][0])

        # Collect every metric of the chosen best run for this task.
        best_pred = pred[BEST]
        acc.append(best_pred[1])
        auc.append(best_pred[2])
        aupr.append(best_pred[3])
        f1.append(best_pred[4])
        mcc.append(best_pred[5])
        precision.append(best_pred[6])
        recall.append(best_pred[7])

    # Average every metric across tasks.
    acc_ave = np.mean(acc)
    auc_ave = np.mean(auc)
    aupr_ave = np.mean(aupr)
    f1_ave = np.mean(f1)
    mcc_ave = np.mean(mcc)
    precision_ave = np.mean(precision)
    recall_ave = np.mean(recall)

    print("acc: " + str(acc_ave))
    print("auc: " + str(auc_ave))
    print("aupr: " + str(aupr_ave))
    print("f1: ", str(f1_ave))
    print("mcc: " + str(mcc_ave))
    print("precision: ", str(precision_ave))
    print("recall: " + str(recall_ave))

    # find and print the tasks whose results are worst
    ranks = np.argsort(auc)[:args.num_worst]
    print("Top %d worst tasks: " % (args.num_worst))
    for i in ranks:
        print(tasks[i] + " %3f %3f" % (acc[i], auc[i]))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def Compute_690(args):
    """Summarize the 690-task benchmark results and print averaged metrics.

    Reads ``args.num_results`` whitespace-separated result lines per task
    from ``args.pred_path`` (columns: task, acc, auc, aupr, f1, mcc), keeps
    the run that is best according to ``args.index``, prints the averaged
    acc/auc/f1/mcc, and lists the ``args.num_worst`` tasks with the lowest
    AUC.
    """
    # Bug fix: the original opened the results file without ever closing it.
    with open(args.pred_path, "r") as result_file:
        results = result_file.readlines()

    # Keep the task name plus the acc, auc, f1 and mcc columns of each line
    # (column 3, aupr, is intentionally skipped as in the original).
    preds = []
    for result in results:
        scores = result.split()
        preds.append([scores[0], float(scores[1]), float(scores[2]), float(scores[4]), float(scores[5])])

    num_results = args.num_results
    num_example = int(len(preds) / num_results)
    print("Num of tasks: %d" % num_example)

    # Column index of each supported sort metric within a parsed row.
    sort_columns = {"acc": 1, "auc": 2, "f1": 3, "mcc": 4}

    tasks = []
    acc = []
    auc = []
    f1 = []
    mcc = []

    for i in range(num_example):
        tasks.append(preds[i * num_results][0])

        current_preds = preds[i * num_results:(i + 1) * num_results]
        if args.index not in sort_columns:
            raise ValueError()
        current_preds.sort(key=lambda val: val[sort_columns[args.index]])
        # Best run for this task under the chosen metric is last after sorting.
        best_pred = current_preds[-1]
        acc.append(best_pred[1])
        auc.append(best_pred[2])
        f1.append(best_pred[3])
        mcc.append(best_pred[4])

    # calculate and print the average scores
    acc_ave = np.mean(acc)
    auc_ave = np.mean(auc)
    f1_ave = np.mean(f1)
    mcc_ave = np.mean(mcc)

    print("acc: " + str(acc_ave))
    print("auc: " + str(auc_ave))
    print("f1: ", str(f1_ave))
    print("mcc: " + str(mcc_ave))

    # find and print the tasks whose results are worst
    ranks = np.argsort(auc)[:args.num_worst]
    print("Top %d worst tasks: " % (args.num_worst))
    for i in ranks:
        print(tasks[i] + " %3f %3f" % (acc[i], auc[i]))
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def main():
    """Parse command-line options and dispatch to the requested evaluation."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--bound",
        default=0.5,
        type=float,
        help="K-mer",
    )
    arg_parser.add_argument(
        "--pred_path",
        default=None,
        type=str,
        help="The path of the predicted result",
    )
    arg_parser.add_argument(
        "--label_path",
        default=None,
        type=str,
        help="The path of the label",
    )
    arg_parser.add_argument(
        "--metric",
        default="max",
        type=str,
        help="The metric of computing predited result (scan)",
    )
    arg_parser.add_argument(
        "--slide",
        default=3,
        type=int,
        help="How many 500s to use for the predictes result of 1000 (scan)",
    )
    arg_parser.add_argument(
        "--task",
        default="scan",
        type=str,
        help="Which task to compute result",
    )
    arg_parser.add_argument(
        "--index",
        default="acc",
        type=str,
        help="Which index to sort result (690)",
    )
    arg_parser.add_argument(
        "--num_results",
        default="10",
        type=int,
        help="Number of results for each task (690)",
    )
    arg_parser.add_argument(
        "--num_worst",
        default="10",
        type=int,
        help="Number of worst tasks to print out (690)",
    )

    args = arg_parser.parse_args()

    # Dispatch table replaces the original if/elif chain; unknown task
    # names still raise ValueError exactly as before.
    handlers = {"scan": Compute_scan, "690": Compute_690, "mouse": Compute_mouse}
    if args.task not in handlers:
        raise ValueError()
    handlers[args.task](args)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
main()
|
examples/data_process_template/.process_pretrain_data_multi.py.swp
ADDED
|
Binary file (4.1 kB). View file
|
|
|
examples/data_process_template/process_690.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
from process_pretrain_data import get_kmer_sentence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
def Process(args):
    """Convert each 690-ENCODE dataset folder of .npy arrays into shuffled
    train.tsv / dev.tsv files of k-mer sentences.

    Each folder under args.file_path must contain train/ and test/ subdirs
    holding sequences_alph.npy (byte strings) and targets.npy.  Output goes
    to <args.output_path>/<kmer>/<folder>/{train,dev}.tsv.
    """
    all_folders = os.listdir(args.file_path)

    count = 0
    for folder in all_folders:
        # Load raw byte-string sequences and integer labels for this dataset.
        train_sequences = np.load(os.path.join(args.file_path, folder, "train", "sequences_alph.npy"))
        test_sequences = np.load(os.path.join(args.file_path, folder, "test", "sequences_alph.npy"))
        train_labels = np.load(os.path.join(args.file_path, folder, "train", "targets.npy"))
        test_labels = np.load(os.path.join(args.file_path, folder, "test", "targets.npy"))

        # Reshape to column vectors so sequence and label can be concatenated row-wise.
        train_sequences = train_sequences.reshape(train_sequences.shape[0], 1)
        test_sequences = test_sequences.reshape(test_sequences.shape[0], 1)
        train_labels = train_labels.reshape(train_labels.shape[0], 1)
        test_labels = test_labels.reshape(test_labels.shape[0], 1)

        # Pair sequence and label together so shuffling keeps them aligned.
        trains = list(np.concatenate((train_sequences, train_labels), axis=1))
        tests = list(np.concatenate((test_sequences, test_labels), axis=1))

        # Fixed seed keeps the output order reproducible across runs.
        random.seed(24)
        random.shuffle(trains)
        random.shuffle(trains)
        random.shuffle(tests)
        random.shuffle(tests)

        output_path = os.path.join(args.output_path, str(args.kmer), folder)
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        # BUG FIX: the original never closed f_train/f_dev; use context
        # managers so the TSVs are always flushed and closed.
        with open(os.path.join(output_path, "train.tsv"), "wt") as f_train:
            tsv_train = csv.writer(f_train, delimiter="\t")
            tsv_train.writerow(["sequence", "label"])
            for seq, label in trains:
                sentence = get_kmer_sentence(seq.decode("utf-8"), args.kmer)
                tsv_train.writerow([sentence, int(label)])

        with open(os.path.join(output_path, "dev.tsv"), "wt") as f_dev:
            tsv_dev = csv.writer(f_dev, delimiter="\t")
            tsv_dev.writerow(["sequence", "label"])
            for seq, label in tests:
                sentence = get_kmer_sentence(seq.decode("utf-8"), args.kmer)
                tsv_dev.writerow([sentence, int(label)])

        count += 1
        print("Finish %s folders" % (count))
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
def main():
    """Parse command-line options and run the 690-dataset conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=1, type=int, help="K-mer")
    parser.add_argument("--file_path", default=None, type=str,
                        help="The path of the file to be processed")
    parser.add_argument("--output_path", default=None, type=str,
                        help="The path of the processed data")
    Process(parser.parse_args())


if __name__ == "__main__":
    main()
examples/data_process_template/process_csv.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
import random
|
| 6 |
+
from process_pretrain_data import get_kmer_sentence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
max_length = 0
|
| 10 |
+
|
def Process_pair(args):
    """Build paired-sequence (enhancer, promoter, label) train/dev[/test]
    files for an enhancer-promoter interaction dataset.

    Reads <file_path>/<cell>_{enhancer,promoter}[_test].fasta and the
    matching _label[_test].txt files, where <cell> is the last path
    component of args.file_path.  With --dev, 10% of training pairs become
    dev and the held-out test pairs go under test/dev.<suffix>.
    """
    random.seed(42)

    root_path = args.file_path.split('/')[-1]

    def _read_lines(name):
        # BUG FIX: original leaked all six input handles; read and close.
        with open(args.file_path + "/" + root_path + name, "r") as f:
            return f.readlines()

    train_seq1 = _read_lines("_enhancer.fasta")
    train_seq2 = _read_lines("_promoter.fasta")
    train_label = _read_lines("_label.txt")
    test_seq1 = _read_lines("_enhancer_test.fasta")
    test_seq2 = _read_lines("_promoter_test.fasta")
    test_label = _read_lines("_label_test.txt")

    # FASTA files alternate header/sequence lines: sequence i sits at 2*i+1.
    train_lines = [[train_seq1[2 * i + 1], train_seq2[2 * i + 1], train_label[i]]
                   for i in range(len(train_label))]
    test_lines = [[test_seq1[2 * i + 1], test_seq2[2 * i + 1], test_label[i]]
                  for i in range(len(test_label))]

    random.shuffle(train_lines)

    if args.dev:
        num_dev = int(len(train_lines) / 10)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    output_path = make_path(args)

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    def _write_pairs(lines, writer):
        # One (kmer-seq1, kmer-seq2, label) row per pair.
        for line in lines:
            seq1 = get_kmer_sentence(line[0], kmer=args.kmer, stride=args.stride)
            seq2 = get_kmer_sentence(line[1], kmer=args.kmer, stride=args.stride)
            writer.writerow([seq1, seq2, str(int(line[2]))])

    with open(os.path.join(output_path, "train" + suffix), 'wt') as f_train:
        train_w = csv.writer(f_train, delimiter=delimiter)
        train_w.writerow(["seq1", "seq2", "label"])
        _write_pairs(train_lines, train_w)

    if args.dev:
        with open(os.path.join(output_path, "dev" + suffix), 'wt') as f_dev:
            dev_w = csv.writer(f_dev, delimiter=delimiter)
            dev_w.writerow(["seq1", "seq2", "label"])
            _write_pairs(dev_lines, dev_w)
        # exist_ok makes reruns robust (original crashed if test/ existed).
        os.makedirs(os.path.join(output_path, "test"), exist_ok=True)
        test_target = os.path.join(output_path, "test", "dev" + suffix)
    else:
        test_target = os.path.join(output_path, "dev" + suffix)

    with open(test_target, 'wt') as f_test:
        test_w = csv.writer(f_test, delimiter=delimiter)
        test_w.writerow(["seq1", "seq2", "label"])
        _write_pairs(test_lines, test_w)
| 75 |
+
|
| 76 |
+
|
def make_path(args):
    """Return the output directory, creating it when missing.

    Falls back to <file_path>/<kmer> when --output_path was not supplied.
    """
    if args.output_path:
        target = args.output_path
    else:
        target = os.path.join(args.file_path, str(args.kmer))
    if not os.path.exists(target):
        os.makedirs(target)
    return target
| 82 |
+
|
def write_file(lines, writer, seq_index=2, label_index=3, kmer=6, stride=1):
    """Write one (kmer-sentence, label) row per input row.

    label_index == -100 is a sentinel meaning "no label available": a
    constant 0 is written instead.  Side effect: updates the module-level
    max_length with the longest sentence (in tokens) seen so far.
    """
    global max_length
    for row in lines:
        sentence = get_kmer_sentence(row[seq_index], kmer=kmer, stride=stride)
        token_count = len(sentence.split())
        if token_count > max_length:
            max_length = token_count
        label = str(0) if label_index == -100 else str(row[label_index])
        writer.writerow([sentence, label])
| 93 |
+
|
def Process(args):
    """Convert <file_path>/{train,test}.csv into shuffled k-mer
    train/dev[/test] files under the directory returned by make_path.

    With --dev, 1/9 of the shuffled training rows are split off as dev and
    the test rows are written as test.<suffix>; otherwise test rows become
    dev.<suffix>.
    """
    random.seed(24)

    # BUG FIX: original left both input handles open; read then close.
    with open(os.path.join(args.file_path, "train.csv"), "r", encoding="utf-8-sig") as train_file:
        train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
    with open(os.path.join(args.file_path, "test.csv"), "r", encoding="utf-8-sig") as test_file:
        test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    if args.dev:
        num_dev = int(len(train_lines) / 9)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    print(train_lines[0])

    output_path = make_path(args)

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    # BUG FIX: original ignored --kmer/--stride (write_file defaulted to
    # kmer=6, stride=1) and wrote the dev split with the *default*
    # seq/label indices instead of args.seq_index/args.label_index.
    with open(os.path.join(output_path, "train" + suffix), 'wt') as f_train:
        train_w = csv.writer(f_train, delimiter=delimiter)
        train_w.writerow(["sentence", "label"])
        write_file(train_lines, train_w, args.seq_index, args.label_index,
                   args.kmer, args.stride)

    test_name = ("test" if args.dev else "dev") + suffix
    with open(os.path.join(output_path, test_name), 'wt') as f_test:
        test_w = csv.writer(f_test, delimiter=delimiter)
        test_w.writerow(["sentence", "label"])
        write_file(test_lines, test_w, args.seq_index, args.label_index,
                   args.kmer, args.stride)

    if args.dev:
        with open(os.path.join(output_path, "dev" + suffix), 'wt') as f_dev:
            dev_w = csv.writer(f_dev, delimiter=delimiter)
            dev_w.writerow(["sentence", "label"])
            write_file(dev_lines, dev_w, args.seq_index, args.label_index,
                       args.kmer, args.stride)

    print("max length: %d" % (max_length))
| 145 |
+
|
| 146 |
+
|
def Process_UCE(args):
    """Expand a mutation CSV into a dev file of k-mer sequences.

    Per input row: column 8 is written (reference sequence), then the
    second-to-last column (mutation 1), and — only when it differs — the
    last column (mutation 2).  All rows get placeholder label 0.  Also
    dumps line2index.json (csv row -> 1-based output row ids) and
    lencount.json (sequence-length histogram).
    """
    len_count = {}
    line2index = {}

    # BUG FIX: original never closed the input or output handles.
    with open(args.file_path, "r", encoding="utf-8-sig") as pred_file:
        pred_lines = list(csv.reader(pred_file, delimiter=",", quotechar=None))[1:]

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    with open(os.path.join(args.output_path, "dev" + suffix), 'wt') as f_pred:
        pred_w = csv.writer(f_pred, delimiter=delimiter)
        pred_w.writerow(["sentence", "label"])

        index = 1          # 1-based row id in the emitted dev file
        line_num = 0       # 0-based row id in the source csv
        for line in pred_lines:
            len_count[len(line[8])] = len_count.get(len(line[8]), 0) + 1
            len_count[len(line[-2])] = len_count.get(len(line[-2]), 0) + 1

            cur_index = [index, index + 1]
            pred_w.writerow([get_kmer_sentence(line[8], args.kmer, args.stride), 0])
            pred_w.writerow([get_kmer_sentence(line[-2], args.kmer, args.stride), 0])
            index += 2

            # Second mutation only emitted when it differs from the first.
            if line[-2] != line[-1]:
                len_count[len(line[-1])] = len_count.get(len(line[-1]), 0) + 1
                pred_w.writerow([get_kmer_sentence(line[-1], args.kmer, args.stride), 0])
                cur_index.append(index)
                index += 1

            line2index[line_num] = cur_index
            line_num += 1

    with open(os.path.join(args.output_path, "line2index.json"), "w") as f:
        json.dump(line2index, f)
    with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
        json.dump(len_count, f)
| 191 |
+
|
| 192 |
+
|
def Process_Virus(args):
    """Like Process_UCE, but pools every CSV under file_path whose name
    does not start with "unclass", then emits one dev file plus the
    line2index.json / lencount.json bookkeeping files.

    NOTE(review): the committed version was broken — it iterated the
    undefined name `pred_lines`, used undefined `len_count`/`line2index`,
    and ended in an empty `with` block (a SyntaxError).  This
    reconstruction mirrors Process_UCE over the pooled `all_lines`;
    confirm the intended column layout (ref at index 8, mutations in the
    last two columns) against the virus CSVs.
    """
    len_count = {}
    line2index = {}

    file_path = args.file_path
    all_files = [f for f in os.listdir(file_path) if not f.startswith("unclass")]
    all_lines = []
    for name in all_files:
        with open(os.path.join(file_path, name), "r", encoding="utf-8-sig") as cur_file:
            all_lines.extend(list(csv.reader(cur_file, delimiter=",", quotechar=None))[1:])

    suffix = '.csv' if args.csv else '.tsv'
    delimiter = ',' if args.csv else '\t'

    with open(os.path.join(args.output_path, "dev" + suffix), 'wt') as f_pred:
        pred_w = csv.writer(f_pred, delimiter=delimiter)
        pred_w.writerow(["sentence", "label"])

        index = 1
        line_num = 0
        for line in all_lines:
            cur_index = [index, index + 1]
            pred_w.writerow([get_kmer_sentence(line[8], args.kmer, args.stride), 0])
            pred_w.writerow([get_kmer_sentence(line[-2], args.kmer, args.stride), 0])
            index += 2

            if line[-2] != line[-1]:
                len_count[len(line[-1])] = len_count.get(len(line[-1]), 0) + 1
                pred_w.writerow([get_kmer_sentence(line[-1], args.kmer, args.stride), 0])
                cur_index.append(index)
                index += 1

            line2index[line_num] = cur_index
            line_num += 1

    with open(os.path.join(args.output_path, "line2index.json"), "w") as f:
        json.dump(line2index, f)
    with open(os.path.join(args.output_path, "lencount.json"), "w") as f:
        json.dump(len_count, f)
| 239 |
+
|
| 240 |
+
|
| 241 |
+
|
def main():
    """CLI entry point: dispatch to pair / UCE / generic csv processing."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=1, type=int, help="K-mer")
    parser.add_argument("--stride", default=1, type=int,
                        help="stride in getting kmer sequence")
    parser.add_argument("--file_path", default=None, type=str,
                        help="The path of the file to be processed")
    parser.add_argument("--output_path", default=None, type=str,
                        help="The path of the processed data")
    parser.add_argument("--dev", action="store_true",
                        help="Use this flag to split data as (8:1:1), else (9:1)")
    parser.add_argument("--csv", action="store_true",
                        help="if output csv file or not, if not, output tsv")
    # BUG FIX: the --pair/--uce help strings were copy-pasted from --dev.
    parser.add_argument("--pair", action="store_true",
                        help="Process paired enhancer/promoter data (Process_pair)")
    parser.add_argument("--uce", action="store_true",
                        help="Process a UCE mutation csv into prediction inputs (Process_UCE)")
    parser.add_argument("--seq_index", default=2, type=int,
                        help="index of seq in the original csv file")
    parser.add_argument("--label_index", default=3, type=int,
                        help="index of label in the original csv file")
    args = parser.parse_args()

    if args.pair:
        Process_pair(args)
    elif args.uce:
        Process_UCE(args)
    else:
        Process(args)


if __name__ == "__main__":
    main()
examples/data_process_template/process_finetune_data.py
ADDED
|
@@ -0,0 +1,713 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from process_pretrain_data import get_kmer_sentence
|
| 7 |
+
|
| 8 |
+
max_length = 0
|
| 9 |
+
|
def write_file(lines, path, kmer, head=True, seq_index=0, label_index=1):
    """Write rows as a tab-separated (sentence, label) file.

    kmer == 0 writes the sequence verbatim; otherwise whitespace is
    stripped and the sequence is re-tokenized via get_kmer_sentence.
    label_index=None writes a constant "0" label (unlabeled data).
    """
    with open(path, 'wt') as f:
        tsv_w = csv.writer(f, delimiter='\t')
        if head:
            # BUG FIX: header said "setence"; other processors in this
            # repo write "sentence".
            tsv_w.writerow(["sentence", "label"])
        for line in lines:
            if kmer == 0:
                sentence = str(line[seq_index])
            else:
                sentence = str(get_kmer_sentence("".join(line[seq_index].split()), kmer))
            # `is None` rather than `== None` (identity test for the sentinel).
            label = "0" if label_index is None else str(line[label_index])
            tsv_w.writerow([sentence, label])
| 25 |
+
|
| 26 |
+
|
def Shuffle(args):
    """Shuffle the data rows of a headered TSV file in place.

    Reads every row after the header, shuffles, and rewrites the same
    path (write_file re-emits the header because head defaults to True).
    """
    # BUG FIX: close the read handle before rewriting the same file.
    with open(args.file_path, "r", encoding="utf-8-sig") as old_file:
        old_lines = list(csv.reader(old_file, delimiter="\t", quotechar=None))[1:]
    random.shuffle(old_lines)

    # kmer=0 keeps each sequence verbatim.
    write_file(old_lines, args.file_path, 0)
| 33 |
+
|
def Find_train(args):
    """Build promoter train/dev TSVs.

    Every row of the TATA/noTATA pools that is not already present in the
    fixed test files becomes training data; then per-kmer (3..6) copies of
    train/dev are written under <file_path>/<kmer>/ with tata/ and notata/
    sub-splits of the dev set.
    """
    random.seed(args.seed)

    # Full candidate pools (header row dropped).
    tata = args.file_path + "/TATA_249to50.tsv"
    notata = args.file_path + "/noTATA_249to50.tsv"
    tata_file = open(tata, "r", encoding="utf-8-sig")
    notata_file = open(notata, "r", encoding="utf-8-sig")
    tata_lines = list(csv.reader(tata_file, delimiter="\t", quotechar=None))[1:]
    notata_lines = list(csv.reader(notata_file, delimiter="\t", quotechar=None))[1:]

    # Held-out test rows to be excluded from training.
    tata_test = args.file_path + "/tata_test.tsv"
    notata_test = args.file_path + "/notata_test.tsv"
    tata_test_file = open(tata_test, "r", encoding="utf-8-sig")
    notata_test_file = open(notata_test, "r", encoding="utf-8-sig")
    tata_test_lines = list(csv.reader(tata_test_file, delimiter="\t", quotechar=None))[1:]
    notata_test_lines = list(csv.reader(notata_test_file, delimiter="\t", quotechar=None))[1:]

    train_lines = []

    # NOTE(review): `not in` over a list of lists is O(n*m), and the match
    # only fires if test rows are exactly [seq, label] — confirm the test
    # files' column count before converting this to a set-based lookup.
    for line in tata_lines:
        if [line[0], line[1]] not in tata_test_lines:
            train_lines.append([line[0], line[1]])

    for line in notata_lines:
        if [line[0], line[1]] not in notata_test_lines:
            train_lines.append([line[0], line[1]])

    random.shuffle(train_lines)
    random.shuffle(train_lines)

    # num_dev = int(len(train_lines)/9.0)
    # dev_lines = train_lines[:num_dev]
    # train_lines = train_lines[num_dev:]

    # Raw (kmer untouched) training file written first, then re-read below.
    write_file(train_lines, args.file_path+"/train.tsv", args.kmer, head=False)
    # write_file(dev_lines, args.file_path+"/dev.tsv", args.kmer)

    for kmer in range(3,7):
        root_path = os.path.join(args.file_path, str(kmer))
        if not os.path.exists(root_path):
            os.makedirs(root_path)

        # Re-read the raw train file and re-tokenize for this kmer.
        train_file = open(os.path.join(args.file_path,"train.tsv"), "r", encoding="utf-8-sig")
        lines = list(csv.reader(train_file, delimiter="\t", quotechar=None))
        train_path = os.path.join(root_path,"train.tsv")

        write_file(lines, train_path, kmer)

        # dev set split into tata-only / notata-only / combined files.
        tata_path = os.path.join(root_path, "tata")
        notata_path = os.path.join(root_path, "notata")
        os.makedirs(tata_path)
        os.makedirs(notata_path)

        dev_lines = tata_test_lines+notata_test_lines
        dev_path = os.path.join(root_path,"dev.tsv")

        write_file(tata_test_lines, os.path.join(tata_path, "dev.tsv"), kmer)
        write_file(notata_test_lines, os.path.join(notata_path, "dev.tsv"), kmer)
        write_file(dev_lines, dev_path, kmer)
| 96 |
+
|
def Process_1000(args):
    """Balance the TATA / noTATA scan splits and record the sampled
    notata test ids.

    Loads the four *_scan_{train,test}.csv files, shuffles each split,
    truncates the (larger) noTATA splits to the TATA sizes, and writes the
    chosen notata test rows' metadata to notata_test_id.  The TSV-writing
    steps at the end are currently disabled (commented out) — only the id
    file is produced.
    """
    random.seed(args.seed)

    # NOTE(review): train paths concatenate without a separator, so
    # file_path must end in "/" for these two (the test paths insert one).
    tata_train = args.file_path + "TATA_scan_train.csv"
    notata_train = args.file_path + "noTATA_scan_train.csv"
    tata_train_file = open(tata_train, "r", encoding="utf-8-sig")
    notata_train_file = open(notata_train, "r", encoding="utf-8-sig")
    tata_train_lines = list(csv.reader(tata_train_file, delimiter=",", quotechar=None))[1:]
    notata_train_lines = list(csv.reader(notata_train_file, delimiter=",", quotechar=None))[1:]

    tata_test = args.file_path + "/TATA_scan_test.csv"
    notata_test = args.file_path + "/noTATA_scan_test.csv"
    tata_test_file = open(tata_test, "r", encoding="utf-8-sig")
    notata_test_file = open(notata_test, "r", encoding="utf-8-sig")
    tata_test_lines = list(csv.reader(tata_test_file, delimiter=",", quotechar=None))[1:]
    notata_test_lines = list(csv.reader(notata_test_file, delimiter=",", quotechar=None))[1:]

    print("Original:")
    print("tata train: %d" % (len(tata_train_lines)))
    print("notata train: %d" % (len(notata_train_lines)))
    print("tata test: %d" % (len(tata_test_lines)))
    print("tata test: %d" % (len(notata_test_lines)))

    random.shuffle(tata_train_lines)
    random.shuffle(notata_train_lines)
    random.shuffle(tata_test_lines)
    random.shuffle(notata_test_lines)

    # Downsample noTATA to match TATA counts (class balancing).
    notata_train_lines = notata_train_lines[:len(tata_train_lines)]
    notata_test_lines = notata_test_lines[:len(tata_test_lines)]
    # Record which notata test rows were kept, by their metadata columns.
    with open(os.path.join(args.file_path, "notata_test_id"), "w") as f:
        tsv_w = csv.writer(f, delimiter=',')
        tsv_w.writerow(["index", "chrom", "start", "end", "name", "strand", "keys", "id"])
        for line in notata_test_lines:
            tsv_w.writerow([line[0], line[1], line[2], line[3], line[4], line[5], line[7], line[9]])

    # print("After:")
    # print("tata train: %d" % (len(tata_train_lines)))
    # print("notata train: %d" % (len(notata_train_lines)))
    # print("tata test: %d" % (len(tata_test_lines)))
    # print("tata test: %d" % (len(notata_test_lines)))

    # train_lines = tata_train_lines + notata_train_lines
    # test_lines = tata_test_lines + notata_test_lines

    # output_path = args.output_path if args.output_path is not None else args.file_path

    # write_file(test_lines, output_path+"/dev.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(train_lines, output_path+"/train.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(tata_test_lines, output_path+"/tata_dev.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(tata_train_lines, output_path+"/tata_train.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(notata_test_lines, output_path+"/notata_dev.tsv", args.kmer, head=False, seq_index=8, label_index=6)
    # write_file(notata_train_lines, output_path+"/notata_train.tsv", args.kmer, head=False, seq_index=8, label_index=6)

    # Process_1000_kmer(args, test_lines, train_lines, tata_test_lines, tata_train_lines, notata_test_lines, notata_train_lines)
| 157 |
+
|
| 158 |
+
|
def Process_1000_kmer(args, test_lines=None, train_lines=None, tata_test_lines=None, tata_train_lines=None, notata_test_lines=None, notata_train_lines=None):
    """Write per-kmer (3..6) train/dev files for the all/notata splits.

    Two entry modes:
      * called with line lists (raw scan-csv rows): LOAD stays True and
        sequences/labels are read from columns 8/6;
      * called with no lists: the six previously written tsv files are
        loaded, LOAD is set False, and columns 0/1 are used instead.
    """
    LOAD = True
    output_path = args.output_path if args.output_path is not None else args.file_path

    # NOTE(review): `== None` should idiomatically be `is None`; behavior
    # is the same for list/None arguments, so the code is left untouched.
    if test_lines == None:
        path1 = os.path.join(args.file_path,"dev.tsv")
        path2 = os.path.join(args.file_path,"train.tsv")
        path3 = os.path.join(args.file_path,"tata_dev.tsv")
        path4 = os.path.join(args.file_path,"tata_train.tsv")
        path5 = os.path.join(args.file_path,"notata_dev.tsv")
        path6 = os.path.join(args.file_path,"notata_train.tsv")

        file1 = open(path1, "r", encoding="utf-8-sig")
        file2 = open(path2, "r", encoding="utf-8-sig")
        file3 = open(path3, "r", encoding="utf-8-sig")
        file4 = open(path4, "r", encoding="utf-8-sig")
        file5 = open(path5, "r", encoding="utf-8-sig")
        file6 = open(path6, "r", encoding="utf-8-sig")

        test_lines = list(csv.reader(file1, delimiter="\t", quotechar=None))
        train_lines = list(csv.reader(file2, delimiter="\t", quotechar=None))
        tata_test_lines = list(csv.reader(file3, delimiter="\t", quotechar=None))
        tata_train_lines = list(csv.reader(file4, delimiter="\t", quotechar=None))
        notata_test_lines = list(csv.reader(file5, delimiter="\t", quotechar=None))
        notata_train_lines = list(csv.reader(file6, delimiter="\t", quotechar=None))

        # Loaded tsv rows are already (sentence, label) pairs.
        LOAD = False

    for kmer in range(3,7):

        print(kmer)
        root_path = os.path.join(output_path, str(kmer))
        if not os.path.exists(root_path):
            os.makedirs(root_path)

        # NOTE(review): os.makedirs without exist_ok raises if a previous
        # run already created all/ or notata/ — rerun requires a clean dir.
        all_path = os.path.join(root_path, "all")
        # tata_path = os.path.join(root_path, "tata")
        notata_path = os.path.join(root_path, "notata")
        os.makedirs(all_path)
        # os.makedirs(tata_path)
        os.makedirs(notata_path)

        # Column layout depends on whether rows came raw (csv) or from tsv.
        if LOAD:
            seq_index=8
            label_index=6
        else:
            seq_index=0
            label_index=1

        print("writing dev")
        write_file(test_lines, os.path.join(all_path,"dev.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        print("writing train")
        write_file(train_lines, os.path.join(all_path,"train.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        # print("writing tata dev")
        # write_file(tata_test_lines, os.path.join(tata_path,"dev.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        # print("writing tata train")
        # write_file(tata_train_lines, os.path.join(tata_path,"train.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        print("writing notata dev")
        write_file(notata_test_lines, os.path.join(notata_path,"dev.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
        print("writing notata train")
        write_file(notata_train_lines, os.path.join(notata_path,"train.tsv"), kmer, head=False, seq_index=seq_index, label_index=label_index)
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def Process_splice(args):
    """Convert the splice-site test arrays into k-mer dev.tsv files for k=3..6.

    Reads x_test.npy / y_test.npy from args.file_path; y_test rows are one-hot
    and the written label is the index of the hot column.
    """
    X_test = np.load(os.path.join(args.file_path, "x_test.npy"))
    Y_test = np.load(os.path.join(args.file_path, "y_test.npy"))

    assert len(X_test) == len(Y_test)

    for kmer in range(3, 7):
        root_path = os.path.join(args.file_path, str(kmer))
        os.makedirs(root_path)
        # `with` closes the output handle (the original leaked it)
        with open(os.path.join(root_path, "dev.tsv"), "wt") as f_test:
            tsv_test = csv.writer(f_test, delimiter='\t')
            tsv_test.writerow(["seq", "label"])

            for i, seq in enumerate(X_test):
                sequence = get_kmer_sentence(str(seq), kmer)
                # one-hot row -> integer class index
                label = int(np.where(Y_test[i]==1)[0])
                tsv_test.writerow([sequence, label])
def Process_prom_core(args):
    """Split TATA/noTATA promoter-core csvs into train/dev tsvs for k=3..6."""
    random.seed(args.seed)

    tata = args.file_path + "/TATA.csv"
    notata = args.file_path + "/noTATA.csv"
    # `with` closes both csvs (the original leaked the handles)
    with open(tata, "r", encoding="utf-8-sig") as tata_file:
        tata_lines = list(csv.reader(tata_file, delimiter=",", quotechar=None))[1:]
    with open(notata, "r", encoding="utf-8-sig") as notata_file:
        notata_lines = list(csv.reader(notata_file, delimiter=",", quotechar=None))[1:]

    random.shuffle(tata_lines)
    random.shuffle(notata_lines)

    # hold out 10% of each subset as its test split
    num_tata_test = int(0.1*len(tata_lines))
    tata_test_lines = tata_lines[:num_tata_test]
    num_notata_test = int(0.1*len(notata_lines))
    notata_test_lines = notata_lines[:num_notata_test]

    train_lines = tata_lines[num_tata_test:] + notata_lines[num_notata_test:]
    if args.dev:
        # BUG FIX: the original referenced undefined `rest_lines` (NameError
        # whenever --dev is set); the intent is to carve the dev split out of
        # the remaining training lines.
        num_dev = int(len(train_lines)/9.0)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]
    else:
        dev_lines = tata_test_lines + notata_test_lines

    print("Number train examples: %d" % (len(train_lines)))
    print("Number dev examples: %d" % (len(dev_lines)))

    for kmer in range(3,7):
        root_path = os.path.join(args.file_path,str(kmer))
        tata_path = os.path.join(root_path, "tata")
        notata_path = os.path.join(root_path, "notata")
        os.makedirs(tata_path)
        os.makedirs(notata_path)

        write_file(tata_test_lines, os.path.join(tata_path,"dev.tsv"), kmer, head=False, seq_index=1, label_index=2)
        write_file(notata_test_lines, os.path.join(notata_path,"dev.tsv"), kmer, head=False, seq_index=1, label_index=2)
        write_file(train_lines, os.path.join(root_path,"train.tsv"), kmer, head=False, seq_index=1, label_index=2)
        write_file(dev_lines, os.path.join(root_path,"dev.tsv"), kmer, head=False, seq_index=1, label_index=2)
def Process_pair(args):
    """Build enhancer/promoter pair-classification tsvs from fasta + label files.

    Reads <name>_enhancer.fasta, <name>_promoter.fasta and <name>_label.txt
    (plus *_test variants) under args.file_path and writes train.tsv and
    dev.tsv (and a test/dev.tsv when --dev is set) of k-mer sentence pairs.
    """
    random.seed(args.seed)

    def _read_lines(path):
        # read a whole text file; `with` guarantees the handle is closed
        # (the original leaked all six input handles)
        with open(path, "r") as f:
            return f.readlines()

    root_path = args.file_path.split('/')[-1]
    base = args.file_path + "/" + root_path
    train_seq1 = _read_lines(base + "_enhancer.fasta")
    train_seq2 = _read_lines(base + "_promoter.fasta")
    train_label = _read_lines(base + "_label.txt")
    test_seq1 = _read_lines(base + "_enhancer_test.fasta")
    test_seq2 = _read_lines(base + "_promoter_test.fasta")
    test_label = _read_lines(base + "_label_test.txt")

    # FASTA layout: the sequence of record i is on line 2*i+1 (2*i is a header)
    train_lines = []
    test_lines = []
    for i in range(len(train_label)):
        train_lines.append([train_seq1[2*i+1], train_seq2[2*i+1], train_label[i]])
    for i in range(len(test_label)):
        test_lines.append([test_seq1[2*i+1], test_seq2[2*i+1], test_label[i]])

    random.shuffle(train_lines)

    if args.dev:
        num_dev = int(len(train_lines)/10)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    output_path = args.output_path if args.output_path else os.path.join(args.file_path, str(args.kmer))
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    def write_file_pair(lines, writer, seq1_index=0, seq2_index=1, label_index=2):
        # convert both sequences to k-mer sentences and emit one row per pair
        for line in lines:
            seq1 = get_kmer_sentence(line[seq1_index], args.kmer)
            seq2 = get_kmer_sentence(line[seq2_index], args.kmer)
            writer.writerow([seq1, seq2, str(int(line[label_index]))])

    # `with` closes every writer handle (the original leaked them)
    with open(os.path.join(output_path, "train.tsv"), 'wt') as f_train:
        train_w = csv.writer(f_train, delimiter='\t')
        train_w.writerow(["seq1", "seq2", "label"])
        write_file_pair(train_lines, train_w)

    if args.dev:
        with open(os.path.join(output_path, "dev.tsv"), 'wt') as f_dev:
            dev_w = csv.writer(f_dev, delimiter='\t')
            dev_w.writerow(["seq1", "seq2", "label"])
            write_file_pair(dev_lines, dev_w)
        os.makedirs(os.path.join(output_path, "test"))
        test_path = os.path.join(output_path, "test", "dev.tsv")
    else:
        test_path = os.path.join(output_path, "dev.tsv")

    with open(test_path, 'wt') as f_test:
        test_w = csv.writer(f_test, delimiter='\t')
        test_w.writerow(["seq1", "seq2", "label"])
        write_file_pair(test_lines, test_w)
def Process_p53_mut(args):
    """Convert the p53-mutation dev.csv into k-mer dev.tsv files for k=3..6.

    label_index=None: the output contains sequences only (no labels).
    """
    random.seed(args.seed)

    dev = os.path.join(args.file_path, "dev.csv")
    # `with` closes the handle (the original leaked it)
    with open(dev, "r", encoding="utf-8-sig") as dev_file:
        lines = list(csv.reader(dev_file, delimiter=",", quotechar=None))[1:]

    print(lines[0])

    for kmer in range(3, 7):
        output_path = args.output_path if args.output_path else os.path.join(args.file_path, str(kmer))
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        write_file(lines, os.path.join(output_path, "dev.tsv"), kmer, head=True, seq_index=2, label_index=None)
def Process_p53(args):
    """Split the p53 train/test csvs into k-mer train/dev(/test) tsvs for k=3..6."""
    random.seed(args.seed)

    train = os.path.join(args.file_path, "train.csv")
    test = os.path.join(args.file_path, "test.csv")
    # `with` closes both csvs (the original leaked the handles)
    with open(train, "r", encoding="utf-8-sig") as train_file:
        train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
    with open(test, "r", encoding="utf-8-sig") as test_file:
        test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]
    lines = train_lines + test_lines

    # longest sequence (column 2), tracked purely for the final report
    max_length = 0
    for line in lines:
        if len(line[2]) > max_length:
            max_length = len(line[2])

    random.shuffle(train_lines)
    random.shuffle(test_lines)

    if args.dev:
        num_dev = int(len(train_lines)/9)
        dev_lines = train_lines[:num_dev]
        train_lines = train_lines[num_dev:]

    print(train_lines[0])

    for kmer in range(3, 7):
        output_path = args.output_path if args.output_path else os.path.join(args.file_path, str(kmer))
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        write_file(train_lines, os.path.join(output_path, "train.tsv"), kmer, head=True, seq_index=2, label_index=3)
        if args.dev:
            write_file(dev_lines, os.path.join(output_path, "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)
            os.makedirs(os.path.join(output_path, "test"))
            write_file(test_lines, os.path.join(output_path, "test", "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)
        else:
            write_file(test_lines, os.path.join(output_path, "dev.tsv"), kmer, head=True, seq_index=2, label_index=3)

    print("max length: %d" % (max_length))
def Seperate_p53(args):
    """Split the combined p53 train+test rows into positive/negative dev tsvs.

    Rows whose last column is '0' go to NEG, everything else to POS; each set
    is written as POS/<k>/dev.tsv and NEG/<k>/dev.tsv for k=3..6.
    """
    random.seed(args.seed)

    train = os.path.join(args.file_path, "train.csv")
    test = os.path.join(args.file_path, "test.csv")
    # `with` closes both csvs (the original leaked the handles)
    with open(train, "r", encoding="utf-8-sig") as train_file:
        train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
    with open(test, "r", encoding="utf-8-sig") as test_file:
        test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]
    lines = train_lines + test_lines

    POS = []
    NEG = []

    # last column is the label, second-to-last the sequence
    for line in lines:
        if str(line[-1]) == '0':
            NEG.append([line[-2], line[-1]])
        else:
            POS.append([line[-2], line[-1]])

    for kmer in range(3,7):
        os.makedirs(os.path.join(args.file_path, "POS", str(kmer)))
        os.makedirs(os.path.join(args.file_path, "NEG", str(kmer)))

        write_file(POS, os.path.join(args.file_path, "POS", str(kmer), "dev.tsv"), kmer=kmer, head=True, seq_index=0, label_index=1)
        write_file(NEG, os.path.join(args.file_path, "NEG", str(kmer), "dev.tsv"), kmer=kmer, head=True, seq_index=0, label_index=1)
def Generate_prom_train_dev(args):
    """Build promoter train/dev splits from the TATA and noTATA tsv files."""
    # BUG FIX: the original assigned the noTATA file to `tata` and the TATA
    # file to `notata`, which mislabeled the tata_dev/notata_dev outputs below.
    tata = args.file_path + "/TATA_249to50.tsv"
    notata = args.file_path + "/noTATA_249to50.tsv"
    # `with` closes both inputs (the original leaked the handles)
    with open(tata, "r", encoding="utf-8-sig") as tata_file:
        tata_lines = list(csv.reader(tata_file, delimiter="\t", quotechar=None))[1:]
    with open(notata, "r", encoding="utf-8-sig") as notata_file:
        notata_lines = list(csv.reader(notata_file, delimiter="\t", quotechar=None))[1:]

    # shuffle all the data, hold out 10% of each subset for dev
    random.shuffle(tata_lines)
    random.shuffle(notata_lines)
    num_tata_test = int(len(tata_lines)*0.1)
    tata_test_lines = tata_lines[:num_tata_test]
    num_notata_test = int(len(notata_lines)*0.1)
    notata_test_lines = notata_lines[:num_notata_test]
    train_lines = tata_lines[num_tata_test:] + notata_lines[num_notata_test:]
    test_lines = tata_test_lines + notata_test_lines

    write_file(train_lines, args.file_path+"/train.tsv", args.kmer)
    write_file(test_lines, args.file_path+"/dev.tsv", args.kmer)
    write_file(tata_test_lines, args.file_path+"/tata_dev.tsv", args.kmer)
    write_file(notata_test_lines, args.file_path+"/notata_dev.tsv", args.kmer)
def Process_690(args):
    """Convert each dataset folder under args.file_path into k-mer tsvs.

    Each folder is expected to hold train/ and test/ subdirs with
    sequences_alph.npy (byte strings) and targets.npy; output goes to
    args.output_path/<kmer>/<folder>/{train,dev}.tsv.
    """
    path = args.file_path
    all_folders = os.listdir(path)

    count = 0

    for folder in all_folders:
        # load data
        train_seq_path = os.path.join(args.file_path, folder, "train", "sequences_alph.npy")
        test_seq_path = os.path.join(args.file_path, folder, "test", "sequences_alph.npy")
        train_lab_path = os.path.join(args.file_path, folder, "train", "targets.npy")
        test_lab_path = os.path.join(args.file_path, folder, "test", "targets.npy")
        train_sequences = np.load(train_seq_path)
        test_sequences = np.load(test_seq_path)
        train_labels = np.load(train_lab_path)
        test_labels = np.load(test_lab_path)

        train_sequences = train_sequences.reshape(train_sequences.shape[0],1)
        test_sequences = test_sequences.reshape(test_sequences.shape[0],1)
        train_labels = train_labels.reshape(train_labels.shape[0],1)
        test_labels = test_labels.reshape(test_labels.shape[0],1)

        # concat sequence and labels together
        trains = list(np.concatenate((train_sequences, train_labels), axis=1))
        tests = list(np.concatenate((test_sequences, test_labels), axis=1))

        random.seed(args.seed)
        # NOTE: each list is shuffled twice in the original; kept as-is so the
        # resulting ordering (given the seed) is unchanged
        random.shuffle(trains)
        random.shuffle(trains)
        random.shuffle(tests)
        random.shuffle(tests)

        # make output path
        output_path = os.path.join(args.output_path, str(args.kmer), folder)
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        # write files; `with` closes the handles (the original leaked both
        # writers on every folder iteration)
        with open(os.path.join(output_path, "train.tsv"), 'wt') as f_train:
            tsv_train = csv.writer(f_train, delimiter='\t')
            tsv_train.writerow(["sequence", "label"])
            for i in range(len(trains)):
                sentence = get_kmer_sentence(trains[i][0].decode("utf-8"), args.kmer)
                tsv_train.writerow([sentence, int(trains[i][1])])

        with open(os.path.join(output_path, "dev.tsv"), 'wt') as f_dev:
            tsv_dev = csv.writer(f_dev, delimiter='\t')
            tsv_dev.writerow(["sequence", "label"])
            for i in range(len(tests)):
                sentence = get_kmer_sentence(tests[i][0].decode("utf-8"), args.kmer)
                tsv_dev.writerow([sentence, int(tests[i][1])])

        count += 1
        print("Finish %s folders" % (count))
def Process_mouse(args):
    """Pair up per-dataset mouse train/test csv files and report their sizes.

    Files under args.file_path are expected in matching *test*/*train* pairs
    when sorted; the tsv-writing step is currently disabled (was commented out
    in the original).
    """
    random.seed(args.seed)

    files = os.listdir(args.file_path)

    # drop entries "3".."6" — presumably earlier k-mer output directories;
    # ignore the error if any of them is absent
    try:
        files.remove("3")
        files.remove("4")
        files.remove("5")
        files.remove("6")
    except ValueError:
        pass

    files.sort()
    assert len(files) % 2 == 0

    num_task = int(len(files)/2)

    max_length = 0

    for i in range(num_task):
        # zero-padded two-digit task index (kept for parity with the original)
        index = str(i) if i > 9 else "0" + str(i)

        # sorted order places each dataset's test file right before its train
        # file; renaming "test" -> "train" must match the partner exactly
        test_name = files[2*i].replace("test", "train")
        train_name = files[2*i+1]
        assert test_name == train_name

        test_path = os.path.join(args.file_path, files[2*i])
        train_path = os.path.join(args.file_path, files[2*i+1])
        # `with` closes the handles (the original leaked them every iteration)
        with open(train_path, "r", encoding="utf-8-sig") as train_file:
            train_lines = list(csv.reader(train_file, delimiter=",", quotechar=None))[1:]
        with open(test_path, "r", encoding="utf-8-sig") as test_file:
            test_lines = list(csv.reader(test_file, delimiter=",", quotechar=None))[1:]

        print("dataset %d : %d lines" % (i, len(train_lines)))
def Process(args):
    """Generic single-file conversion: read a delimited file and write it as a
    k-mer tsv via write_file, using the seq/label column indexes from args."""
    if args.output_path is not None:
        output_path = args.output_path
    else:
        # default output: <dir>/<kmer>/<original filename>
        root_path = "/".join(args.file_path.split("/")[:-1]) + "/" + str(args.kmer) + "/"
        output_path = root_path + args.file_path.split("/")[-1]
        if not os.path.exists(root_path):
            os.makedirs(root_path)

    # `with` closes the input handle (the original leaked it)
    with open(args.file_path, "r", encoding="utf-8-sig") as old_file:
        lines = list(csv.reader(old_file, delimiter=args.delimiter, quotechar=None))

    write_file(lines, output_path, args.kmer, head=args.head, seq_index=args.seq_index, label_index=args.label_index)
def main():
    """CLI entry point: parse arguments and dispatch to the requested task."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kmer",
        default=1,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--seed",
        default=24,
        type=int,
        help="Which random seed to use",
    )
    parser.add_argument(
        "--task",
        default="",
        type=str,
        help="which task to do",
    )
    parser.add_argument(
        "--file_path",
        default=None,
        type=str,
        help="The path of the file to be processed",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        help="The path of the processed data",
    )
    parser.add_argument(
        "--delimiter",
        default=',',
        type=str,
        help="The path of the processed data",
    )
    parser.add_argument(
        "--head",
        action="store_true",
        help="The path of the processed data",
    )
    parser.add_argument(
        "--dev",
        action="store_true",
        help="Use this flag to split data as (8:1:1), else (9:1)",
    )
    parser.add_argument(
        "--seq_index",
        default=2,
        type=int,
        help="index of seq in the original csv file",
    )
    parser.add_argument(
        "--label_index",
        default=3,
        type=int,
        help="index of label in the original csv file",
    )
    args = parser.parse_args()

    # map each task name to its handler; anything unrecognized falls back to
    # the generic Process (same behavior as the original elif chain)
    handlers = {
        "generate_prom": Generate_prom_train_dev,
        "shuffle": Shuffle,
        "find_train": Find_train,
        "prom_1000": Process_1000,
        "prom_1000_kmer": Process_1000_kmer,
        "splice": Process_splice,
        "pair": Process_pair,
        "p53": Process_p53,
        "p53_mut": Process_p53_mut,
        "sep_p53": Seperate_p53,
        "690": Process_690,
        "mouse": Process_mouse,
        "prom-core": Process_prom_core,
    }
    handlers.get(args.task, Process)(args)


if __name__ == "__main__":
    main()
|
examples/data_process_template/process_ner.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import os
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import random
|
| 7 |
+
from process_pretrain_data import get_kmer_sequence
|
| 8 |
+
from multiprocessing import Pool
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_example(X, Y, kmer, index):
    """Turn raw byte sequences and sparse label rows into token/label pairs.

    For worker `index`, converts each X[j] into a list of k-mers and builds a
    per-position label list padded with 200 leading and 201-kmer trailing
    zeros around the hit positions from Y[j].
    """
    examples = []
    for pos, raw in enumerate(X):
        if pos % 1000 == 0:
            print("%s : %s" % (index, pos))

        # 200 zero labels in front, hit indices in the middle, 201-kmer behind
        label = list(np.zeros(200, dtype=int))
        label += list(np.where(Y[pos] == 1)[1])
        label += list(np.zeros(201 - kmer, dtype=int))

        tokens = get_kmer_sequence(raw.decode("utf-8"), kmer)
        examples.append([tokens, label])

    return examples
def Process(args):
    """Read chunked X/Y arrays from an hdf5 file, k-merize the sequences in
    parallel, and write token/label pairs in CoNLL-style token-per-line format.
    """
    filename = args.file_path
    h5 = h5py.File(filename, "r")
    # keys are assumed to come in X.../Y... pairs, with the X chunks forming
    # the first half in key order — TODO confirm against the producer script
    num_chunks = len(h5.keys())//2
    keys = list(h5.keys())[:num_chunks]

    X = []

    for i, key in enumerate(keys):
        x_key = key
        # the matching label chunk has the same name with X replaced by Y
        y_key = x_key.replace("X","Y")

        X_l = h5[x_key]
        Y_l = h5[y_key][0]

        X.extend(X_l)

        if i == 0:
            Y = Y_l
        else:
            Y = np.concatenate([Y, Y_l], axis=0)

        print("%d : %d, %d, %s" % (i, len(X), Y.shape[0], str(key)))

    print(len(X))
    print(len(Y))

    # fan the conversion out over n_proc workers, each taking one contiguous
    # slice of X/Y; the last slice absorbs the division remainder
    n_proc = int(args.n_process)
    print("number of processes for converting feature: " + str(n_proc))
    p = Pool(n_proc)
    indexes = [0]
    len_slice = int(len(X)/n_proc)
    for i in range(1, n_proc+1):
        if i != n_proc:
            indexes.append(len_slice*(i))
        else:
            indexes.append(len(X))

    results = []

    for i in range(n_proc):
        results.append(p.apply_async(generate_example, args=(X[indexes[i]:indexes[i+1]], Y[indexes[i]:indexes[i+1]], args.kmer, i)))
        print(str(i+1) + ' processor started !')

    p.close()
    p.join()

    # async results were appended in slice order, so extending in order
    # reassembles the examples in the original sequence order
    lines = []
    for result in results:
        lines.extend(result.get())


    # output goes next to the input as <kmer>/train.txt
    # NOTE(review): the <kmer> directory is assumed to exist already — confirm,
    # otherwise this open() raises FileNotFoundError
    path = "/".join(args.file_path.split('/')[:-1]) + "/" + str(args.kmer) + "/train.txt"
    print(path)
    file = open(path, "w")
    for line in lines:
        # one "token label" pair per line, blank line between sequences
        for k, word in enumerate(line[0]):
            file.write(str(word) + " " + str(line[1][k]) + "\n")
        file.write("\n")
def main():
    """Parse the command line and run the NER preprocessing pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kmer",
        default=1,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--n_process",
        default=24,
        type=int,
        help="Number of processes for data processing",
    )
    parser.add_argument(
        "--file_path",
        default=None,
        type=str,
        help="The path of the file to be processed",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        help="The path of the processed data",
    )
    Process(parser.parse_args())


if __name__ == "__main__":
    main()
examples/data_process_template/process_pretrain_data.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def cut_no_overlap(length, kmer=1, max_prob=0.5):
    """Chop `length` into consecutive chunk sizes that sum to `length`.

    The largest chunk is 509+kmer characters. At each step a full-size chunk
    is taken with probability `max_prob`; otherwise a random size in
    [5, 509+kmer) is drawn. The final chunk is whatever remains.
    """
    limit = 509 + kmer
    cuts = []
    remaining = length
    while remaining:
        if remaining <= limit:
            # the tail fits into a single chunk
            cuts.append(remaining)
            break
        if random.random() > max_prob:
            cut = max(int(random.random() * limit), 5)
        else:
            cut = limit
        cuts.append(cut)
        remaining -= cut
    return cuts
def sampling(length, kmer=1, sampling_rate=1):
    """Draw random (start, end) windows over a sequence of `length` characters.

    Roughly length*sampling_rate/256 windows are drawn; each window size is
    in [5, 509+kmer) and each start leaves room for at least one k-mer.
    Returns parallel lists of starts and ends.
    """
    num_windows = int(length * sampling_rate / 256)
    starts, ends = [], []
    for _ in range(num_windows):
        size = max(int(random.random() * (509 + kmer)), 5)
        begin = np.random.randint(length - kmer)
        starts.append(begin)
        ends.append(begin + size)
    return starts, ends
def sampling_fix(length, kmer=1, sampling_rate=1, fix_length=10245):
    """Draw fixed-size random windows from a sequence of `length` characters.

    Roughly length*sampling_rate/fix_length windows are drawn, each exactly
    fix_length long. Returns parallel lists of starts and ends.
    """
    num_windows = int(length * sampling_rate / fix_length)
    starts, ends = [], []
    for _ in range(num_windows):
        # NOTE: the start offset subtracts a hard-coded 6 rather than `kmer`
        # — preserved exactly as in the original
        begin = np.random.randint(length - 6 - fix_length)
        starts.append(begin)
        ends.append(begin + fix_length)
    return starts, ends
def get_kmer_sentence(original_string, kmer=1, stride=1):
    """Convert a raw sequence into a space-separated sentence of k-mers.

    Args:
        original_string: input sequence; embedded newlines are stripped.
        kmer: k-mer size; -1 returns the input unchanged.
        stride: step between consecutive k-mer start positions.

    Returns:
        The space-joined k-mers, with double quotes stripped from both ends.
    """
    if kmer == -1:
        return original_string

    original_string = original_string.replace("\n", "")
    kmers = []
    i = 0
    # BUG FIX: the original loop condition was `i < len(...) - kmer`, which
    # silently dropped the final k-mer (e.g. "ATCG", k=2 -> "AT TC" instead of
    # "AT TC CG"). `<=` includes it, consistent with get_kmer_sequence below.
    while i <= len(original_string) - kmer:
        kmers.append(original_string[i:i+kmer])
        i += stride

    return " ".join(kmers).strip("\"")
def get_kmer_sequence(original_string, kmer=1):
    """Split a sequence into a list of overlapping k-mers.

    kmer == -1 returns the input unchanged; embedded newlines are stripped.
    """
    if kmer == -1:
        return original_string

    cleaned = original_string.replace("\n", "")
    # all full windows except the last, then the final k characters
    tokens = [cleaned[i:i + kmer] for i in range(len(cleaned) - kmer)]
    tokens.append(cleaned[-kmer:])
    return tokens
def Process(args):
    """Cut or sample raw genome lines into k-mer sentences for pretraining.

    Reads args.file_path line by line. When sampling_rate != 1.0 it draws
    fixed-length random windows via sampling_fix; otherwise it chops each
    line into non-overlapping chunks via cut_no_overlap. Every fragment is
    written to the output file as one k-mer sentence per line.
    """
    if args.output_path is None:
        args.output_path = args.file_path

    # the output filename encodes the mode (sampled vs. cut) and the k-mer size
    if args.sampling_rate != 1.0:
        new_file_path = args.output_path + "_sam" + str(args.kmer)
    else:
        new_file_path = args.output_path + "_cut" + str(args.kmer)

    # `with` closes both handles even on error (the original leaked them)
    with open(args.file_path, "r") as old_file, open(new_file_path, "w") as new_file:
        for line in old_file:
            line_length = len(line)
            if args.sampling_rate != 1.0:
                starts, ends = sampling_fix(length=line_length, kmer=args.kmer, sampling_rate=args.sampling_rate, fix_length=args.length)
                for i in range(len(starts)):
                    new_line = line[starts[i]:ends[i]]
                    sentence = get_kmer_sentence(new_line, kmer=args.kmer)
                    new_file.write(sentence + "\n")
            else:
                cuts = cut_no_overlap(length=line_length, kmer=args.kmer)
                start = 0
                for cut in cuts:
                    new_line = line[start:start+cut]
                    sentence = get_kmer_sentence(new_line, kmer=args.kmer)
                    start += cut
                    new_file.write(sentence + "\n")
def main():
    """CLI entry point for the pretraining-data processor."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sampling_rate",
        default=1.0,
        type=float,
        help="We will sample sampling_rate*total_length*2/512 times",
    )
    parser.add_argument(
        "--kmer",
        default=1,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--length",
        default=10000,
        type=int,
        help="Length of the sampled sequence",
    )
    parser.add_argument(
        "--file_path",
        default=None,
        type=str,
        help="The path of the file to be processed",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        help="The path of the processed data",
    )
    Process(parser.parse_args())


if __name__ == "__main__":
    main()
|
examples/data_process_template/process_pretrain_data_multi.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from multiprocessing import Pool
|
| 2 |
+
import copy
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from process_pretrain_data import Process
|
| 6 |
+
|
| 7 |
+
# filenames = ['xaa', 'xab', 'xac', 'xad', 'xae', 'xaf', 'xag', 'xah', 'xai', 'xaj', 'xak', 'xal', 'xam', 'xan', 'xao', 'xap', 'xaq', 'xar', 'xas', 'xat', 'xau', 'xav', 'xaw']
|
| 8 |
+
# filenames = ['xaa', 'xab']
|
| 9 |
+
|
| 10 |
+
def main():
    """Sample pretraining data from chromosomes 1-22 in parallel.

    One pool worker runs ``Process`` on each chromosome FASTA file.
    Worker exceptions are collected and re-raised at the end (the original
    ``apply_async`` calls dropped them silently).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sampling_rate",
        default=1.0,
        type=float,
        help="We will sample sampling_rate*total_length*2/512 times",
    )
    parser.add_argument(
        "--kmer",
        default=1,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--length",
        default=10000,
        type=int,
        help="Length of the sampled sequence",
    )
    parser.add_argument(
        "--file_path",
        default=None,
        type=str,
        help="The path of the file to be processed",
    )
    parser.add_argument(
        "--output_path",
        default="/home/zhihan/dna/data/split/",
        type=str,
        help="The path of the file to be processed",
    )

    args = parser.parse_args()

    # multiprocess: one worker per chromosome (1-22).
    pool = Pool(22)

    pending = []
    for i in range(1, 23):
        arg_new = copy.deepcopy(args)
        # NOTE(review): the CLI --file_path/--output_path values are
        # overridden here with hard-coded cluster paths -- confirm this
        # is intentional before reusing the script elsewhere.
        arg_new.file_path = "/root/data/genome/" + "GRCh38.chr" + str(i) + ".fa"
        arg_new.output_path = "/root/data/sub_001_6140/" + "GRCh38.chr" + str(i) + ".fa"
        pending.append(pool.apply_async(Process, args=(arg_new,)))

    pool.close()
    pool.join()

    # apply_async swallows worker exceptions unless the AsyncResults are
    # consumed; .get() re-raises any failure instead of finishing silently.
    for result in pending:
        result.get()


if __name__ == "__main__":
    main()
|
examples/data_process_template/process_scan_prom_data.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import csv
|
| 4 |
+
import numpy as np
|
| 5 |
+
from process_pretrain_data import get_kmer_sentence
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def Process(args):
    """Expand each labelled sequence into overlapping 500 bp scan windows.

    Reads a comma-separated file (header row skipped) where column 6 holds
    the label and column 8 the sequence, writes one k-mer sentence per scan
    window to ``dev.tsv`` under the output directory, and saves the
    per-sequence labels to ``label.npy``.
    """
    # Evenly spaced window start offsets covering positions 0..500.
    if args.slide > 1:
        scan_offsets = [int(500 / (args.slide - 1)) * i for i in range(args.slide)]
    else:
        # A single window starting at 0; the original formula would have
        # divided by zero for slide == 1.
        scan_offsets = [0]

    # Context managers guarantee the handles are closed even on error
    # (the original code leaked both file objects).
    with open(args.file_path, "r", encoding="utf-8-sig") as old_file:
        old_lines = list(csv.reader(old_file, delimiter=",", quotechar=None))[1:]

    if args.output_path:
        root_path = args.output_path + "/"
    else:
        root_path = "/".join(args.file_path.split("/")[:-1]) + "/" + str(args.kmer) + "/"
    if not os.path.exists(root_path):
        os.makedirs(root_path)

    labels = np.array([])
    with open(root_path + "dev.tsv", "wt") as new_file:
        tsv_w = csv.writer(new_file, delimiter="\t")
        # NOTE(review): the "setence" typo is preserved because downstream
        # readers may expect this exact header -- confirm before fixing.
        tsv_w.writerow(["setence", "label"])

        for line in old_lines:
            label = line[6]
            labels = np.append(labels, int(label))

            for index in scan_offsets:
                sub_sequence = line[8][index:index + 500]
                sub_sentence = get_kmer_sentence(sub_sequence, kmer=args.kmer)
                tsv_w.writerow([sub_sentence, label])

    np.save(root_path + "label.npy", labels)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def main():
    """Parse CLI options and run the scan-window preprocessing."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kmer",
        default=1,
        type=int,
        help="K-mer",
    )
    parser.add_argument(
        "--file_path",
        default=None,
        type=str,
        help="The path of the file to be processed",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        help="The path of the processed data",
    )
    parser.add_argument(
        "--slide",
        default=11,
        type=int,
        # Fixed user-facing typo: "predictes" -> "predicted".
        help="How many 500s to use for the predicted result of 1000",
    )
    args = parser.parse_args()

    Process(args)


if __name__ == "__main__":
    main()
|
examples/gen_cCRE_emb_final.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
|
| 5 |
+
from Bio import SeqIO
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# ========== CONFIG ==========
# Pretrained checkpoint to embed with -- presumably the adaptive-tokenizer
# pretraining run; confirm against the training job.
MODEL_DIR = "/home/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000"
# Directory of cCRE FASTA files (chr1 only) to embed, one file per class.
FASTA_DIR = "/home/n5huang/dna_token/cCRE_classes/chr1_files"
# Destination for the per-file embedding matrices (.npy).
OUTPUT_DIR = "/home/n5huang/dna_token/outputs_cCREemb/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prefer the GPU when one is visible; otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model-type key -> (config class, model class, tokenizer class).
MODEL_CLASSES = {"dna": (BertConfig, BertForMaskedLM, DNATokenizer)}
|
| 17 |
+
|
| 18 |
+
# ========== LOAD MODEL ==========
|
| 19 |
+
def load_model(model_dir):
    """Load the pretrained encoder and its tokenizer from *model_dir*.

    The returned model is the bare BertModel (no MLM head), moved to
    DEVICE and switched to eval mode, ready for embedding extraction.
    """
    cfg_cls, mlm_cls, tok_cls = MODEL_CLASSES['dna']
    print(f"Loading using: {cfg_cls.__name__}, {mlm_cls.__name__}, {tok_cls.__name__}")

    cfg = cfg_cls.from_pretrained(model_dir)
    # Note: BertModel (not the registered MaskedLM wrapper) is what we
    # embed with -- only hidden states are needed.
    encoder = BertModel.from_pretrained(model_dir, config=cfg)
    tok = tok_cls.from_pretrained(model_dir)

    encoder.to(DEVICE)
    encoder.eval()

    print(f"✅ Model loaded on {DEVICE}, vocab size = {len(tok)}")
    return encoder, tok
|
| 32 |
+
|
| 33 |
+
# ========== SEQUENCE HELPERS ==========
|
| 34 |
+
def seq_to_kmers(seq, k=6):
    """Turn a DNA string into a space-separated sentence of overlapping k-mers.

    The input is upper-cased and every 'N' is removed first (note: this
    joins the bases flanking an N run).  Returns "" when fewer than *k*
    bases remain after cleaning.
    """
    cleaned = seq.upper().replace("N", "")
    n_kmers = len(cleaned) - k + 1
    if n_kmers <= 0:
        return ""
    words = []
    for start in range(n_kmers):
        words.append(cleaned[start:start + k])
    return " ".join(words)
|
| 39 |
+
|
| 40 |
+
def get_fasta_sequences(fasta_file):
    """Return the upper-cased sequences from *fasta_file* that are >= 50 bp."""
    records = SeqIO.parse(fasta_file, "fasta")
    candidates = (str(record.seq).upper() for record in records)
    return [seq for seq in candidates if len(seq) >= 50]
|
| 47 |
+
|
| 48 |
+
# ========== EMBEDDING GENERATION ==========
|
| 49 |
+
def get_cls_embeddings(batch_seqs, model, tokenizer, device, max_len=512):
    """Embed a batch of k-mer sentences and return their [CLS] vectors.

    Returns a (batch, hidden) numpy array taken from position 0 of the
    model's first output (the last hidden state).
    """
    encoded = tokenizer.batch_encode_plus(
        batch_seqs,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )
    # Ship every input tensor to the target device before the forward pass.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        model_out = model(**on_device)

    # Position 0 of the last hidden state is the [CLS] embedding.
    return model_out[0][:, 0, :].cpu().numpy()
|
| 67 |
+
|
| 68 |
+
# ========== MAIN EXECUTION ==========
|
| 69 |
+
def main():
    """Generate [CLS] embeddings for every FASTA file in FASTA_DIR.

    For each file the sequences are de-duplicated (order-preserving),
    converted to 6-mer sentences, embedded in batches, and saved as
    <name>_emb.npy under OUTPUT_DIR.
    """
    model, tokenizer = load_model(MODEL_DIR)

    fasta_files = [f for f in os.listdir(FASTA_DIR) if f.endswith(".fa")]
    print(f"\nFound {len(fasta_files)} FASTA files in {FASTA_DIR}")

    for fasta_file in fasta_files:
        fasta_path = os.path.join(FASTA_DIR, fasta_file)
        print(f"\n🚀 Processing: {fasta_file}")

        sequences = get_fasta_sequences(fasta_path)
        if len(sequences) == 0:
            print(f"⚠️ No valid sequences found in {fasta_file}")
            continue

        # --- Remove duplicates (dict.fromkeys preserves first-seen order,
        # unlike set(), so the row order of the saved matrix is
        # deterministic across runs) ---
        unique_sequences = list(dict.fromkeys(sequences))
        if len(unique_sequences) < len(sequences):
            print(f"⚠️ Removed {len(sequences) - len(unique_sequences)} duplicate sequences")

        # --- Convert to k-mers; drop sequences whose k-mer sentence is
        # empty (seq_to_kmers collapses all-N sequences to "") ---
        sentences = (seq_to_kmers(s) for s in unique_sequences if len(s) >= 6)
        kmers = [sentence for sentence in sentences if sentence]
        if not kmers:
            # Guards the kmers[0] access below against an IndexError.
            print(f"⚠️ No valid sequences found in {fasta_file}")
            continue

        # --- Sanity check on tokenization ---
        example_tokens = tokenizer.tokenize(kmers[0])[:10]
        print(f"🔹 Example tokens: {example_tokens}")

        # --- Batch embedding extraction ---
        all_embs = []
        batch_size = 16
        for i in tqdm(range(0, len(kmers), batch_size), desc=f"Embedding {fasta_file}"):
            batch = kmers[i:i+batch_size]
            batch_embs = get_cls_embeddings(batch, model, tokenizer, DEVICE)
            all_embs.append(batch_embs)

        all_embs = np.vstack(all_embs)
        out_path = os.path.join(OUTPUT_DIR, fasta_file.replace(".fa", "_emb.npy"))
        np.save(out_path, all_embs)

        print(f"✅ Saved {all_embs.shape} embeddings to {out_path}")

    print("\n🎉 All cell-type embeddings generated successfully!")

if __name__ == "__main__":
    main()
|
examples/load_model_test.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import BertConfig, BertModel, BertForMaskedLM, DNATokenizer
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
# Define MODEL_CLASSES as it's required by your loadmodel function
# Maps a model-type key to its (config, model, tokenizer) class triple.
MODEL_CLASSES = {
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
    # ... (other classes omitted for brevity)
}
|
| 11 |
+
|
| 12 |
+
def loadmodel(model_dir):
    """Load config, embedding model, and tokenizer from *model_dir*.

    When the registry names BertForMaskedLM, the bare BertModel is loaded
    instead, since only hidden states (not MLM logits) are needed.
    Returns the (config, model, tokenizer) triple with the model on the
    best available device in eval mode.
    """
    cfg_cls, registered_cls, tok_cls = MODEL_CLASSES['dna']
    print(f"Loading using: {cfg_cls.__name__}, {registered_cls.__name__}, {tok_cls.__name__}")

    # 1. Configuration -- the weights below must match it.
    config = cfg_cls.from_pretrained(
        model_dir,
        cache_dir=None,
    )

    # 2. Swap the MLM wrapper for the base transformer when embedding.
    weight_cls = BertModel if registered_cls == BertForMaskedLM else registered_cls

    model = weight_cls.from_pretrained(
        model_dir,
        from_tf=bool(".ckpt" in model_dir),
        config=config,
        cache_dir=None,
    )

    # 3. Inference only: pick a device and freeze in eval mode.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print(f"Model loaded onto device: {device}")

    # 4. Tokenizer from the same checkpoint directory.
    tokenizer = tok_cls.from_pretrained(model_dir)
    print(f"Tokenizer vocabulary size: {len(tokenizer)}")

    return config, model, tokenizer
|
| 47 |
+
|
| 48 |
+
# --- Main Call ---
def _run_smoke_test():
    """Load the model named by --MODEL_DIR and print a few sanity checks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--MODEL_DIR", type=str, required=True)
    args = parser.parse_args()

    # required=True means argparse already exits with an error when the
    # flag is missing, so the old placeholder-path check (and its
    # misleading "environment variable" message) is unnecessary.
    model_dir = args.MODEL_DIR

    config, model, tokenizer = loadmodel(model_dir)
    print("Model and Tokenizer loaded successfully.")

    embedding_layer = model.get_input_embeddings()
    print(embedding_layer.weight.shape)

    # Tokenize a tiny sequence as overlapping 6-mers to sanity-check the vocab.
    seq = "ACGTACGT"
    seq = "ACGTACGTACGT"
    tokens = tokenizer.tokenize(" ".join([seq[i:i+6] for i in range(len(seq)-5)]))
    print(tokens[:10])


# Guarded entry point: the original ran at import time as a module-level
# side effect, which broke importing this file for its helpers.
if __name__ == "__main__":
    _run_smoke_test()
|
| 69 |
+
|
examples/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorboardX
|
| 2 |
+
tensorboard
|
| 3 |
+
scikit-learn >= 0.22.2
|
| 4 |
+
seqeval
|
| 5 |
+
pyahocorasick
|
| 6 |
+
scipy
|
| 7 |
+
statsmodels
|
| 8 |
+
biopython
|
| 9 |
+
pandas
|
| 10 |
+
pybedtools
|
| 11 |
+
sentencepiece==0.1.91
|
examples/run_finetune.py
ADDED
|
@@ -0,0 +1,1284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import glob
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import os
|
| 24 |
+
import re
|
| 25 |
+
import shutil
|
| 26 |
+
import random
|
| 27 |
+
from multiprocessing import Pool
|
| 28 |
+
from typing import Dict, List, Tuple
|
| 29 |
+
from copy import deepcopy
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
| 34 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 35 |
+
from tqdm import tqdm, trange
|
| 36 |
+
|
| 37 |
+
from transformers import (
|
| 38 |
+
WEIGHTS_NAME,
|
| 39 |
+
AdamW,
|
| 40 |
+
AlbertConfig,
|
| 41 |
+
AlbertForSequenceClassification,
|
| 42 |
+
AlbertTokenizer,
|
| 43 |
+
BertConfig,
|
| 44 |
+
BertForSequenceClassification,
|
| 45 |
+
BertForLongSequenceClassification,
|
| 46 |
+
BertForLongSequenceClassificationCat,
|
| 47 |
+
BertTokenizer,
|
| 48 |
+
DNATokenizer,
|
| 49 |
+
DistilBertConfig,
|
| 50 |
+
DistilBertForSequenceClassification,
|
| 51 |
+
DistilBertTokenizer,
|
| 52 |
+
FlaubertConfig,
|
| 53 |
+
FlaubertForSequenceClassification,
|
| 54 |
+
FlaubertTokenizer,
|
| 55 |
+
RobertaConfig,
|
| 56 |
+
RobertaForSequenceClassification,
|
| 57 |
+
RobertaTokenizer,
|
| 58 |
+
XLMConfig,
|
| 59 |
+
XLMForSequenceClassification,
|
| 60 |
+
XLMRobertaConfig,
|
| 61 |
+
XLMRobertaForSequenceClassification,
|
| 62 |
+
XLMRobertaTokenizer,
|
| 63 |
+
XLMTokenizer,
|
| 64 |
+
XLNetConfig,
|
| 65 |
+
XLNetForSequenceClassification,
|
| 66 |
+
XLNetTokenizer,
|
| 67 |
+
get_linear_schedule_with_warmup,
|
| 68 |
+
)
|
| 69 |
+
from transformers import glue_compute_metrics as compute_metrics
|
| 70 |
+
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
| 71 |
+
from transformers import glue_output_modes as output_modes
|
| 72 |
+
from transformers import glue_processors as processors
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 77 |
+
except ImportError:
|
| 78 |
+
from tensorboardX import SummaryWriter
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Module-level logger for this fine-tuning script.
logger = logging.getLogger(__name__)

# Every pretrained shortcut name known to the supported config classes,
# flattened into a single tuple (sum(..., ()) concatenates the per-config
# tuples of archive-map keys).
ALL_MODELS = sum(
    (
        tuple(conf.pretrained_config_archive_map.keys())
        for conf in (
            BertConfig,
            XLNetConfig,
            XLMConfig,
            RobertaConfig,
            DistilBertConfig,
            AlbertConfig,
            XLMRobertaConfig,
            FlaubertConfig,
        )
    ),
    (),
)

# model_type key -> (config class, sequence-classification model class,
# tokenizer class).  The dna* entries reuse BertConfig/DNATokenizer with
# DNA-specific classification heads.
MODEL_CLASSES = {
    "dna": (BertConfig, BertForSequenceClassification, DNATokenizer),
    "dnalong": (BertConfig, BertForLongSequenceClassification, DNATokenizer),
    "dnalongcat": (BertConfig, BertForLongSequenceClassificationCat, DNATokenizer),
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
    "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
    "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
}

# Model types grouped for segment/token-type-id handling -- presumably the
# types that receive token_type_ids downstream; confirm at the call sites.
TOKEN_ID_GROUP = ["bert", "dnalong", "dnalongcat", "xlnet", "albert"]
|
| 115 |
+
|
| 116 |
+
def set_seed(args):
    """Seed the python, numpy and torch RNGs from args.seed for reproducibility."""
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Also seed every CUDA device when GPUs are in use.
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
|
| 125 |
+
ordering_and_checkpoint_path = []
|
| 126 |
+
|
| 127 |
+
glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
|
| 128 |
+
|
| 129 |
+
for path in glob_checkpoints:
|
| 130 |
+
if use_mtime:
|
| 131 |
+
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
| 132 |
+
else:
|
| 133 |
+
regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
|
| 134 |
+
if regex_match and regex_match.groups():
|
| 135 |
+
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
| 136 |
+
|
| 137 |
+
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
| 138 |
+
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
| 139 |
+
return checkpoints_sorted
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
|
| 143 |
+
if not args.save_total_limit:
|
| 144 |
+
return
|
| 145 |
+
if args.save_total_limit <= 0:
|
| 146 |
+
return
|
| 147 |
+
|
| 148 |
+
# Check if we should delete older checkpoint(s)
|
| 149 |
+
checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
|
| 150 |
+
if len(checkpoints_sorted) <= args.save_total_limit:
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
|
| 154 |
+
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
| 155 |
+
for checkpoint in checkpoints_to_be_deleted:
|
| 156 |
+
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
| 157 |
+
shutil.rmtree(checkpoint)
|
| 158 |
+
|
| 159 |
+
def train(args, train_dataset, model, tokenizer):
|
| 160 |
+
""" Train the model """
|
| 161 |
+
if args.local_rank in [-1, 0]:
|
| 162 |
+
tb_writer = SummaryWriter()
|
| 163 |
+
|
| 164 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
| 165 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
| 166 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
| 167 |
+
|
| 168 |
+
if args.max_steps > 0:
|
| 169 |
+
t_total = args.max_steps
|
| 170 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
| 171 |
+
else:
|
| 172 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
| 173 |
+
|
| 174 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
| 175 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
| 176 |
+
optimizer_grouped_parameters = [
|
| 177 |
+
{
|
| 178 |
+
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
| 179 |
+
"weight_decay": args.weight_decay,
|
| 180 |
+
},
|
| 181 |
+
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
| 182 |
+
]
|
| 183 |
+
|
| 184 |
+
warmup_steps = args.warmup_steps if args.warmup_percent == 0 else int(args.warmup_percent*t_total)
|
| 185 |
+
|
| 186 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.beta1,args.beta2))
|
| 187 |
+
scheduler = get_linear_schedule_with_warmup(
|
| 188 |
+
optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Check if saved optimizer or scheduler states exist
|
| 192 |
+
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
| 193 |
+
os.path.join(args.model_name_or_path, "scheduler.pt")
|
| 194 |
+
):
|
| 195 |
+
# Load in optimizer and scheduler states
|
| 196 |
+
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
| 197 |
+
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
| 198 |
+
|
| 199 |
+
if args.fp16:
|
| 200 |
+
try:
|
| 201 |
+
from apex import amp
|
| 202 |
+
except ImportError:
|
| 203 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
| 204 |
+
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
| 205 |
+
|
| 206 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
| 207 |
+
if args.n_gpu > 1:
|
| 208 |
+
model = torch.nn.DataParallel(model)
|
| 209 |
+
|
| 210 |
+
# Distributed training (should be after apex fp16 initialization)
|
| 211 |
+
if args.local_rank != -1:
|
| 212 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
| 213 |
+
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# Train!
|
| 217 |
+
logger.info("***** Running training *****")
|
| 218 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
| 219 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
| 220 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
| 221 |
+
logger.info(
|
| 222 |
+
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
| 223 |
+
args.train_batch_size
|
| 224 |
+
* args.gradient_accumulation_steps
|
| 225 |
+
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
| 226 |
+
)
|
| 227 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
| 228 |
+
logger.info(" Total optimization steps = %d", t_total)
|
| 229 |
+
|
| 230 |
+
global_step = 0
|
| 231 |
+
epochs_trained = 0
|
| 232 |
+
steps_trained_in_current_epoch = 0
|
| 233 |
+
# Check if continuing training from a checkpoint
|
| 234 |
+
if os.path.exists(args.model_name_or_path):
|
| 235 |
+
# set global_step to gobal_step of last saved checkpoint from model path
|
| 236 |
+
try:
|
| 237 |
+
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
| 238 |
+
except:
|
| 239 |
+
global_step = 0
|
| 240 |
+
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
| 241 |
+
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
| 242 |
+
|
| 243 |
+
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
| 244 |
+
logger.info(" Continuing training from epoch %d", epochs_trained)
|
| 245 |
+
logger.info(" Continuing training from global step %d", global_step)
|
| 246 |
+
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
| 247 |
+
|
| 248 |
+
tr_loss, logging_loss = 0.0, 0.0
|
| 249 |
+
model.zero_grad()
|
| 250 |
+
train_iterator = trange(
|
| 251 |
+
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
|
| 252 |
+
)
|
| 253 |
+
set_seed(args) # Added here for reproductibility
|
| 254 |
+
|
| 255 |
+
best_auc = 0
|
| 256 |
+
last_auc = 0
|
| 257 |
+
stop_count = 0
|
| 258 |
+
|
| 259 |
+
for _ in train_iterator:
|
| 260 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
| 261 |
+
for step, batch in enumerate(epoch_iterator):
|
| 262 |
+
|
| 263 |
+
# Skip past any already trained steps if resuming training
|
| 264 |
+
if steps_trained_in_current_epoch > 0:
|
| 265 |
+
steps_trained_in_current_epoch -= 1
|
| 266 |
+
continue
|
| 267 |
+
|
| 268 |
+
model.train()
|
| 269 |
+
batch = tuple(t.to(args.device) for t in batch)
|
| 270 |
+
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
| 271 |
+
if args.model_type != "distilbert":
|
| 272 |
+
inputs["token_type_ids"] = (
|
| 273 |
+
batch[2] if args.model_type in TOKEN_ID_GROUP else None
|
| 274 |
+
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
| 275 |
+
outputs = model(**inputs)
|
| 276 |
+
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
| 277 |
+
|
| 278 |
+
if args.n_gpu > 1:
|
| 279 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
| 280 |
+
if args.gradient_accumulation_steps > 1:
|
| 281 |
+
loss = loss / args.gradient_accumulation_steps
|
| 282 |
+
|
| 283 |
+
if args.fp16:
|
| 284 |
+
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
| 285 |
+
scaled_loss.backward()
|
| 286 |
+
else:
|
| 287 |
+
loss.backward()
|
| 288 |
+
|
| 289 |
+
tr_loss += loss.item()
|
| 290 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
| 291 |
+
if args.fp16:
|
| 292 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
| 293 |
+
else:
|
| 294 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
| 295 |
+
|
| 296 |
+
optimizer.step()
|
| 297 |
+
scheduler.step() # Update learning rate schedule
|
| 298 |
+
model.zero_grad()
|
| 299 |
+
global_step += 1
|
| 300 |
+
|
| 301 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
| 302 |
+
logs = {}
|
| 303 |
+
if (
|
| 304 |
+
args.local_rank == -1 and args.evaluate_during_training
|
| 305 |
+
): # Only evaluate when single GPU otherwise metrics may not average well
|
| 306 |
+
results = evaluate(args, model, tokenizer)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
if args.task_name == "dna690":
|
| 310 |
+
# record the best auc
|
| 311 |
+
if results["auc"] > best_auc:
|
| 312 |
+
best_auc = results["auc"]
|
| 313 |
+
|
| 314 |
+
if args.early_stop != 0:
|
| 315 |
+
# record current auc to perform early stop
|
| 316 |
+
if results["auc"] < last_auc:
|
| 317 |
+
stop_count += 1
|
| 318 |
+
else:
|
| 319 |
+
stop_count = 0
|
| 320 |
+
|
| 321 |
+
last_auc = results["auc"]
|
| 322 |
+
|
| 323 |
+
if stop_count == args.early_stop:
|
| 324 |
+
logger.info("Early stop")
|
| 325 |
+
return global_step, tr_loss / global_step
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
for key, value in results.items():
|
| 329 |
+
eval_key = "eval_{}".format(key)
|
| 330 |
+
logs[eval_key] = value
|
| 331 |
+
|
| 332 |
+
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
| 333 |
+
learning_rate_scalar = scheduler.get_lr()[0]
|
| 334 |
+
logs["learning_rate"] = learning_rate_scalar
|
| 335 |
+
logs["loss"] = loss_scalar
|
| 336 |
+
logging_loss = tr_loss
|
| 337 |
+
|
| 338 |
+
for key, value in logs.items():
|
| 339 |
+
tb_writer.add_scalar(key, value, global_step)
|
| 340 |
+
print(json.dumps({**logs, **{"step": global_step}}))
|
| 341 |
+
|
| 342 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
| 343 |
+
if args.task_name == "dna690" and results["auc"] < best_auc:
|
| 344 |
+
continue
|
| 345 |
+
checkpoint_prefix = "checkpoint"
|
| 346 |
+
# Save model checkpoint
|
| 347 |
+
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
| 348 |
+
if not os.path.exists(output_dir):
|
| 349 |
+
os.makedirs(output_dir)
|
| 350 |
+
model_to_save = (
|
| 351 |
+
model.module if hasattr(model, "module") else model
|
| 352 |
+
) # Take care of distributed/parallel training
|
| 353 |
+
model_to_save.save_pretrained(output_dir)
|
| 354 |
+
tokenizer.save_pretrained(output_dir)
|
| 355 |
+
|
| 356 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
| 357 |
+
|
| 358 |
+
_rotate_checkpoints(args, checkpoint_prefix)
|
| 359 |
+
|
| 360 |
+
if args.task_name != "dna690":
|
| 361 |
+
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
| 362 |
+
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
| 363 |
+
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
| 364 |
+
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
| 365 |
+
|
| 366 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
| 367 |
+
epoch_iterator.close()
|
| 368 |
+
break
|
| 369 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
| 370 |
+
train_iterator.close()
|
| 371 |
+
break
|
| 372 |
+
|
| 373 |
+
if args.local_rank in [-1, 0]:
|
| 374 |
+
tb_writer.close()
|
| 375 |
+
|
| 376 |
+
return global_step, tr_loss / global_step
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def evaluate(args, model, tokenizer, prefix="", evaluate=True):
    """Evaluate `model` on the task's dataset and append metrics to eval_results.txt.

    Args:
        args: parsed command-line namespace (task_name, output_dir, device,
            batch sizes, output_mode, do_ensemble_pred, ...).
        model: classification model; may get wrapped in DataParallel here.
        tokenizer: forwarded to load_and_cache_examples.
        prefix: path/name prefix used in log lines and the results file path.
        evaluate: forwarded to load_and_cache_examples — True loads the dev
            split, False the train split. NOTE(review): this parameter shadows
            the function's own name inside the body.

    Returns:
        dict mapping metric name -> value. When args.do_ensemble_pred is set,
        returns (results, eval_task, preds, out_label_ids, probs) instead.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
    # softmax is only needed (and only bound) for "dna*" tasks; the branches
    # below that use it are all guarded by the same task-name prefix check.
    if args.task_name[:3] == "dna":
        softmax = torch.nn.Softmax(dim=1)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=evaluate)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None          # accumulated logits, shape (num_examples, num_labels)
        probs = None          # stays None unless a dna* classification branch fills it
        out_label_ids = None  # accumulated gold labels
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in TOKEN_ID_GROUP else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            # Grow the numpy accumulators batch by batch.
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            # For dna* tasks compute class probabilities before argmax-ing the
            # logits; ensemble mode keeps the full distribution, otherwise only
            # the positive-class column is kept (dnasplice is 3-way, so it
            # always keeps the full distribution).
            if args.task_name[:3] == "dna" and args.task_name != "dnasplice":
                if args.do_ensemble_pred:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
                else:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32))[:,1].numpy()
            elif args.task_name == "dnasplice":
                probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        if args.do_ensemble_pred:
            result = compute_metrics(eval_task, preds, out_label_ids, probs[:,1])
        else:
            result = compute_metrics(eval_task, preds, out_label_ids, probs)
        results.update(result)

        # dna690 redirects the text report into a dedicated results directory.
        if args.task_name == "dna690":
            eval_output_dir = args.result_dir
            if not os.path.exists(args.result_dir):
                os.makedirs(args.result_dir)
        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        # Append mode: one line per evaluation run.
        with open(output_eval_file, "a") as writer:

            if args.task_name[:3] == "dna":
                # Prefix the line with the dataset directory name for dna tasks.
                eval_result = args.data_dir.split('/')[-1] + " "
            else:
                eval_result = ""

            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                # Metric values truncated to 5 characters, space-separated.
                eval_result = eval_result + str(result[key])[:5] + " "
            writer.write(eval_result + "\n")

    if args.do_ensemble_pred:
        return results, eval_task, preds, out_label_ids, probs
    else:
        return results
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def predict(args, model, tokenizer, prefix=""):
    """Run inference on the dev split and save class probabilities to disk.

    Mirrors evaluate(), but instead of writing a text report it saves the
    probability array to ``<args.predict_dir>/pred_results.npy``.

    Args:
        args: parsed command-line namespace (task_name, predict_dir, device,
            per_gpu_pred_batch_size, output_mode, do_ensemble_pred, ...).
        model: classification model; may get wrapped in DataParallel here.
        tokenizer: forwarded to load_and_cache_examples.
        prefix: label used only in log messages.

    Side effects:
        Creates args.predict_dir if missing; writes pred_results.npy there and
        logs the computed metrics.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    pred_task_names = (args.task_name,)
    pred_outputs_dirs = (args.predict_dir,)
    if not os.path.exists(args.predict_dir):
        os.makedirs(args.predict_dir)
    softmax = torch.nn.Softmax(dim=1)

    for pred_task, pred_output_dir in zip(pred_task_names, pred_outputs_dirs):
        pred_dataset = load_and_cache_examples(args, pred_task, tokenizer, evaluate=True)

        if not os.path.exists(pred_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(pred_output_dir)

        args.pred_batch_size = args.per_gpu_pred_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        pred_sampler = SequentialSampler(pred_dataset)
        pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler, batch_size=args.pred_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running prediction {} *****".format(prefix))
        logger.info("  Num examples = %d", len(pred_dataset))
        logger.info("  Batch size = %d", args.pred_batch_size)
        preds = None          # accumulated logits, shape (num_examples, num_labels)
        # Initialize probs so regression tasks (which never enter the dna*
        # branches below) don't hit a NameError at compute_metrics/np.save.
        probs = None
        out_label_ids = None  # accumulated gold labels
        for batch in tqdm(pred_dataloader, desc="Predicting"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in TOKEN_ID_GROUP else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                _, logits = outputs[:2]

            # Grow the numpy accumulators batch by batch.
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        if args.output_mode == "classification":
            # Same probability handling as evaluate(): full distribution for
            # ensemble mode / 3-way dnasplice, positive column otherwise.
            if args.task_name[:3] == "dna" and args.task_name != "dnasplice":
                if args.do_ensemble_pred:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
                else:
                    probs = softmax(torch.tensor(preds, dtype=torch.float32))[:,1].numpy()
            elif args.task_name == "dnasplice":
                probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()
            preds = np.argmax(preds, axis=1)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)

        if args.do_ensemble_pred:
            result = compute_metrics(pred_task, preds, out_label_ids, probs[:,1])
        else:
            result = compute_metrics(pred_task, preds, out_label_ids, probs)

        pred_output_dir = args.predict_dir
        if not os.path.exists(pred_output_dir):
            # BUG FIX: was os.makedir(...), which does not exist and raised
            # AttributeError; the correct stdlib call is os.makedirs.
            os.makedirs(pred_output_dir)
        output_pred_file = os.path.join(pred_output_dir, "pred_results.npy")
        logger.info("***** Pred results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
        np.save(output_pred_file, probs)
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def format_attention(attention):
    """Stack per-layer attention tensors into one 5-D tensor.

    Args:
        attention: iterable of per-layer attention tensors, each of shape
            (1, num_heads, seq_len, seq_len).

    Returns:
        Tensor of shape (1, num_layers, num_heads, seq_len, seq_len).

    Raises:
        ValueError: if any layer tensor is not 4-dimensional.
    """
    per_layer = []
    for layer in attention:
        # Each layer must be 1 x num_heads x seq_len x seq_len.
        if layer.dim() != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
        per_layer.append(layer.squeeze(0))
    # num_layers x num_heads x seq_len x seq_len, then add a leading batch dim.
    stacked = torch.stack(per_layer)
    return stacked.unsqueeze(0)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def visualize(args, model, tokenizer, kmer, prefix=""):
    """Compute per-position attention scores and class probabilities.

    For every example, sums CLS-column attention over heads, smears each
    k-mer token's score across the k underlying nucleotide positions, and
    L2-normalizes the resulting per-position profile.

    Args:
        args: parsed command-line namespace (task_name, predict_dir, device,
            per_gpu_pred_batch_size, max_seq_length, visualize_train, ...).
        model: model whose outputs include attention tensors
            (assumes output_attentions was enabled — see `attention` below).
        tokenizer: forwarded to load_and_cache_examples.
        kmer: k-mer size used to map token scores back to base positions.
        prefix: label used only in log messages.

    Returns:
        (scores, probs): scores is (num_examples, max_seq_length) of
        normalized attention; probs is positive-class probability per example
        (full distribution for dnasplice).
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    pred_task_names = (args.task_name,)
    pred_outputs_dirs = (args.predict_dir,)
    if not os.path.exists(args.predict_dir):
        os.makedirs(args.predict_dir)
    softmax = torch.nn.Softmax(dim=1)

    for pred_task, pred_output_dir in zip(pred_task_names, pred_outputs_dirs):
        '''
        if args.task_name != "dna690":
            args.data_dir = os.path.join(args.visualize_data_dir, str(kmer))
        else:
            args.data_dir = deepcopy(args.visualize_data_dir).replace("/690", "/690/" + str(kmer))
        '''

        # visualize_train flips which split is scored (train vs dev).
        evaluate = False if args.visualize_train else True
        pred_dataset = load_and_cache_examples(args, pred_task, tokenizer, evaluate=evaluate)

        if not os.path.exists(pred_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(pred_output_dir)

        args.pred_batch_size = args.per_gpu_pred_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        pred_sampler = SequentialSampler(pred_dataset)
        pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler, batch_size=args.pred_batch_size)

        # multi-gpu eval
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running prediction {} *****".format(prefix))
        logger.info("  Num examples = %d", len(pred_dataset))
        logger.info("  Batch size = %d", args.pred_batch_size)
        pred_loss = 0.0
        nb_pred_steps = 0
        batch_size = args.pred_batch_size
        # Logits buffer: 2 classes for most tasks, 3 for dnasplice.
        if args.task_name != "dnasplice":
            preds = np.zeros([len(pred_dataset),2])
        else:
            preds = np.zeros([len(pred_dataset),3])
        # NOTE(review): hard-codes 12 attention heads — assumes a BERT-base
        # style config; confirm against the model being visualized.
        attention_scores = np.zeros([len(pred_dataset), 12, args.max_seq_length, args.max_seq_length])

        for index, batch in enumerate(tqdm(pred_dataloader, desc="Predicting")):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in TOKEN_ID_GROUP else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                # Last element of the model outputs, last entry — presumably
                # the final layer's attention; requires output_attentions=True.
                attention = outputs[-1][-1]
                _, logits = outputs[:2]

            # Write this batch's results into the preallocated slices
            # (len(batch[0]) handles a short final batch).
            preds[index*batch_size:index*batch_size+len(batch[0]),:] = logits.detach().cpu().numpy()
            attention_scores[index*batch_size:index*batch_size+len(batch[0]),:,:,:] = attention.cpu().numpy()
            # if preds is None:
            #     preds = logits.detach().cpu().numpy()
            # else:
            #     preds = np.concatenate((preds, logits.detach().cpu().numpy()), axis=0)

            # if attention_scores is not None:
            #     attention_scores = np.concatenate((attention_scores, attention.cpu().numpy()), 0)
            # else:
            #     attention_scores = attention.cpu().numpy()

        if args.task_name != "dnasplice":
            probs = softmax(torch.tensor(preds, dtype=torch.float32))[:,1].numpy()
        else:
            probs = softmax(torch.tensor(preds, dtype=torch.float32)).numpy()

        scores = np.zeros([attention_scores.shape[0], attention_scores.shape[-1]])

        for index, attention_score in enumerate(attention_scores):
            attn_score = []
            # Sum attention paid TO position i (column i, query row 0 across
            # heads), skipping index 0 — presumably the [CLS] token.
            for i in range(1, attention_score.shape[-1]-kmer+2):
                attn_score.append(float(attention_score[:,0,i].sum()))

            # Zero the token just before the first all-zero (padding) position
            # — presumably to drop the [SEP] boundary token; TODO confirm.
            for i in range(len(attn_score)-1):
                if attn_score[i+1] == 0:
                    attn_score[i] = 0
                    break

            # attn_score[0] = 0
            # Smear each k-mer token's score over its k nucleotide positions,
            # then average by coverage count.
            counts = np.zeros([len(attn_score)+kmer-1])
            real_scores = np.zeros([len(attn_score)+kmer-1])
            for i, score in enumerate(attn_score):
                for j in range(kmer):
                    counts[i+j] += 1.0
                    real_scores[i+j] += score
            real_scores = real_scores / counts
            # L2-normalize each example's score profile.
            real_scores = real_scores / np.linalg.norm(real_scores)

            # print(index)
            # print(real_scores)
            # print(len(real_scores))

            scores[index] = real_scores

        return scores, probs
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    """Build (or load from cache) the TensorDataset for a task split.

    Reads examples via the task's processor, converts them to features
    (optionally with a multiprocessing pool), caches the features with
    torch.save, and returns a TensorDataset of
    (input_ids, attention_mask, token_type_ids, labels).

    Args:
        args: parsed command-line namespace (data_dir, model_name_or_path,
            max_seq_length, n_process, overwrite_cache, local_rank, ...).
        task: task key into the module-level `processors` / `output_modes`.
        tokenizer: tokenizer used for feature conversion and pad-token lookup.
        evaluate: True -> dev split, False -> train split.

    Returns:
        torch.utils.data.TensorDataset with long (classification) or float
        (regression) labels.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            # Last non-empty path segment of the model name/path.
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    # Prediction runs use a cache name without the model segment — so caches
    # are shared across checkpoints when only predicting.
    if args.do_predict:
        cached_features_file = os.path.join(
            args.data_dir,
            "cached_{}_{}_{}".format(
                "dev" if evaluate else "train",
                str(args.max_seq_length),
                str(task),
            ),
        )
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK(label indices are swapped in RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        examples = (
            processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        )

        print("finish loading examples")

        # params for convert_examples_to_features
        max_length = args.max_seq_length
        pad_on_left = bool(args.model_type in ["xlnet"])
        pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
        pad_token_segment_id = 4 if args.model_type in ["xlnet"] else 0

        if args.n_process == 1:
            # Single-process conversion path.
            features = convert_examples_to_features(
                examples,
                tokenizer,
                label_list=label_list,
                max_length=max_length,
                output_mode=output_mode,
                pad_on_left=pad_on_left,  # pad on the left for xlnet
                pad_token=pad_token,
                pad_token_segment_id=pad_token_segment_id,)

        else:
            # Multiprocess conversion: slice `examples` into n_proc contiguous
            # chunks and convert them in a Pool, preserving order on collect.
            n_proc = int(args.n_process)
            # Eval sets are smaller, so use a quarter of the processes.
            if evaluate:
                n_proc = max(int(n_proc/4),1)
            print("number of processes for converting feature: " + str(n_proc))
            p = Pool(n_proc)
            indexes = [0]
            len_slice = int(len(examples)/n_proc)
            for i in range(1, n_proc+1):
                if i != n_proc:
                    indexes.append(len_slice*(i))
                else:
                    # Last chunk absorbs the division remainder.
                    indexes.append(len(examples))

            results = []

            for i in range(n_proc):
                # NOTE(review): positional args here must match
                # convert_examples_to_features' signature — confirm the
                # ordering (examples, tokenizer, max_length, ..., True).
                results.append(p.apply_async(convert_examples_to_features, args=(examples[indexes[i]:indexes[i+1]], tokenizer, max_length, None, label_list, output_mode, pad_on_left, pad_token, pad_token_segment_id, True, )))
                print(str(i+1) + ' processor started !')

            p.close()
            p.join()

            # Concatenate chunk results in submission order.
            features = []
            for result in results:
                features.extend(result.get())

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def main():
|
| 788 |
+
parser = argparse.ArgumentParser()
|
| 789 |
+
|
| 790 |
+
# Required parameters
|
| 791 |
+
parser.add_argument(
|
| 792 |
+
"--data_dir",
|
| 793 |
+
default=None,
|
| 794 |
+
type=str,
|
| 795 |
+
required=True,
|
| 796 |
+
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
| 797 |
+
)
|
| 798 |
+
parser.add_argument(
|
| 799 |
+
"--model_type",
|
| 800 |
+
default=None,
|
| 801 |
+
type=str,
|
| 802 |
+
required=True,
|
| 803 |
+
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
| 804 |
+
)
|
| 805 |
+
parser.add_argument(
|
| 806 |
+
"--n_process",
|
| 807 |
+
default=2,
|
| 808 |
+
type=int,
|
| 809 |
+
help="number of processes used for data process",
|
| 810 |
+
)
|
| 811 |
+
parser.add_argument(
|
| 812 |
+
"--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
|
| 813 |
+
)
|
| 814 |
+
parser.add_argument(
|
| 815 |
+
"--model_name_or_path",
|
| 816 |
+
default=None,
|
| 817 |
+
type=str,
|
| 818 |
+
required=True,
|
| 819 |
+
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
| 820 |
+
)
|
| 821 |
+
parser.add_argument(
|
| 822 |
+
"--task_name",
|
| 823 |
+
default=None,
|
| 824 |
+
type=str,
|
| 825 |
+
required=True,
|
| 826 |
+
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
|
| 827 |
+
)
|
| 828 |
+
parser.add_argument(
|
| 829 |
+
"--output_dir",
|
| 830 |
+
default=None,
|
| 831 |
+
type=str,
|
| 832 |
+
required=True,
|
| 833 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
# Other parameters
|
| 838 |
+
parser.add_argument(
|
| 839 |
+
"--visualize_data_dir",
|
| 840 |
+
default=None,
|
| 841 |
+
type=str,
|
| 842 |
+
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
| 843 |
+
)
|
| 844 |
+
parser.add_argument(
|
| 845 |
+
"--result_dir",
|
| 846 |
+
default=None,
|
| 847 |
+
type=str,
|
| 848 |
+
help="The directory where the dna690 and mouse will save results.",
|
| 849 |
+
)
|
| 850 |
+
parser.add_argument(
|
| 851 |
+
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
|
| 852 |
+
)
|
| 853 |
+
parser.add_argument(
|
| 854 |
+
"--tokenizer_name",
|
| 855 |
+
default="",
|
| 856 |
+
type=str,
|
| 857 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
| 858 |
+
)
|
| 859 |
+
parser.add_argument(
|
| 860 |
+
"--cache_dir",
|
| 861 |
+
default="",
|
| 862 |
+
type=str,
|
| 863 |
+
help="Where do you want to store the pre-trained models downloaded from s3",
|
| 864 |
+
)
|
| 865 |
+
parser.add_argument(
|
| 866 |
+
"--predict_dir",
|
| 867 |
+
default=None,
|
| 868 |
+
type=str,
|
| 869 |
+
help="The output directory of predicted result. (when do_predict)",
|
| 870 |
+
)
|
| 871 |
+
parser.add_argument(
|
| 872 |
+
"--max_seq_length",
|
| 873 |
+
default=128,
|
| 874 |
+
type=int,
|
| 875 |
+
help="The maximum total input sequence length after tokenization. Sequences longer "
|
| 876 |
+
"than this will be truncated, sequences shorter will be padded.",
|
| 877 |
+
)
|
| 878 |
+
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
| 879 |
+
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
| 880 |
+
parser.add_argument("--do_predict", action="store_true", help="Whether to do prediction on the given dataset.")
|
| 881 |
+
parser.add_argument("--do_visualize", action="store_true", help="Whether to calculate attention score.")
|
| 882 |
+
parser.add_argument("--visualize_train", action="store_true", help="Whether to visualize train.tsv or dev.tsv.")
|
| 883 |
+
parser.add_argument("--do_ensemble_pred", action="store_true", help="Whether to do ensemble prediction with kmer 3456.")
|
| 884 |
+
parser.add_argument(
|
| 885 |
+
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
|
| 886 |
+
)
|
| 887 |
+
parser.add_argument(
|
| 888 |
+
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
parser.add_argument(
|
| 892 |
+
"--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
|
| 893 |
+
)
|
| 894 |
+
parser.add_argument(
|
| 895 |
+
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
|
| 896 |
+
)
|
| 897 |
+
parser.add_argument(
|
| 898 |
+
"--per_gpu_pred_batch_size", default=8, type=int, help="Batch size per GPU/CPU for prediction.",
|
| 899 |
+
)
|
| 900 |
+
parser.add_argument(
|
| 901 |
+
"--early_stop", default=0, type=int, help="set this to a positive integet if you want to perfrom early stop. The model will stop \
|
| 902 |
+
if the auc keep decreasing early_stop times",
|
| 903 |
+
)
|
| 904 |
+
parser.add_argument(
|
| 905 |
+
"--predict_scan_size",
|
| 906 |
+
type=int,
|
| 907 |
+
default=1,
|
| 908 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 909 |
+
)
|
| 910 |
+
parser.add_argument(
|
| 911 |
+
"--gradient_accumulation_steps",
|
| 912 |
+
type=int,
|
| 913 |
+
default=1,
|
| 914 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 915 |
+
)
|
| 916 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
| 917 |
+
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
| 918 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
| 919 |
+
parser.add_argument("--beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
|
| 920 |
+
parser.add_argument("--beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
|
| 921 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 922 |
+
parser.add_argument("--attention_probs_dropout_prob", default=0.1, type=float, help="Dropout rate of attention.")
|
| 923 |
+
parser.add_argument("--hidden_dropout_prob", default=0.1, type=float, help="Dropout rate of intermidiete layer.")
|
| 924 |
+
parser.add_argument("--rnn_dropout", default=0.0, type=float, help="Dropout rate of intermidiete layer.")
|
| 925 |
+
parser.add_argument("--rnn", default="lstm", type=str, help="What kind of RNN to use")
|
| 926 |
+
parser.add_argument("--num_rnn_layer", default=2, type=int, help="Number of rnn layers in dnalong model.")
|
| 927 |
+
parser.add_argument("--rnn_hidden", default=768, type=int, help="Number of hidden unit in a rnn layer.")
|
| 928 |
+
parser.add_argument(
|
| 929 |
+
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
|
| 930 |
+
)
|
| 931 |
+
parser.add_argument(
|
| 932 |
+
"--max_steps",
|
| 933 |
+
default=-1,
|
| 934 |
+
type=int,
|
| 935 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
| 936 |
+
)
|
| 937 |
+
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
| 938 |
+
parser.add_argument("--warmup_percent", default=0, type=float, help="Linear warmup over warmup_percent*total_steps.")
|
| 939 |
+
|
| 940 |
+
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
| 941 |
+
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
| 942 |
+
parser.add_argument(
|
| 943 |
+
"--save_total_limit",
|
| 944 |
+
type=int,
|
| 945 |
+
default=None,
|
| 946 |
+
help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
|
| 947 |
+
)
|
| 948 |
+
parser.add_argument(
|
| 949 |
+
"--eval_all_checkpoints",
|
| 950 |
+
action="store_true",
|
| 951 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
| 952 |
+
)
|
| 953 |
+
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
| 954 |
+
parser.add_argument(
|
| 955 |
+
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
|
| 956 |
+
)
|
| 957 |
+
parser.add_argument(
|
| 958 |
+
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
|
| 959 |
+
)
|
| 960 |
+
parser.add_argument(
|
| 961 |
+
"--visualize_models", type=int, default=None, help="The model used to do visualization. If None, use 3456.",
|
| 962 |
+
)
|
| 963 |
+
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
parser.add_argument(
|
| 967 |
+
"--fp16",
|
| 968 |
+
action="store_true",
|
| 969 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
| 970 |
+
)
|
| 971 |
+
parser.add_argument(
|
| 972 |
+
"--fp16_opt_level",
|
| 973 |
+
type=str,
|
| 974 |
+
default="O1",
|
| 975 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
| 976 |
+
"See details at https://nvidia.github.io/apex/amp.html",
|
| 977 |
+
)
|
| 978 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 979 |
+
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
| 980 |
+
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
| 981 |
+
|
| 982 |
+
|
| 983 |
+
args = parser.parse_args()
|
| 984 |
+
|
| 985 |
+
if args.should_continue:
|
| 986 |
+
sorted_checkpoints = _sorted_checkpoints(args)
|
| 987 |
+
if len(sorted_checkpoints) == 0:
|
| 988 |
+
raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
|
| 989 |
+
else:
|
| 990 |
+
args.model_name_or_path = sorted_checkpoints[-1]
|
| 991 |
+
|
| 992 |
+
if (
|
| 993 |
+
os.path.exists(args.output_dir)
|
| 994 |
+
and os.listdir(args.output_dir)
|
| 995 |
+
and args.do_train
|
| 996 |
+
and not args.overwrite_output_dir
|
| 997 |
+
):
|
| 998 |
+
raise ValueError(
|
| 999 |
+
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
| 1000 |
+
args.output_dir
|
| 1001 |
+
)
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
# Setup distant debugging if needed
|
| 1005 |
+
if args.server_ip and args.server_port:
|
| 1006 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
| 1007 |
+
import ptvsd
|
| 1008 |
+
|
| 1009 |
+
print("Waiting for debugger attach")
|
| 1010 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
| 1011 |
+
ptvsd.wait_for_attach()
|
| 1012 |
+
|
| 1013 |
+
# Setup CUDA, GPU & distributed training
|
| 1014 |
+
if args.local_rank == -1 or args.no_cuda:
|
| 1015 |
+
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
| 1016 |
+
args.n_gpu = torch.cuda.device_count()
|
| 1017 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
| 1018 |
+
torch.cuda.set_device(args.local_rank)
|
| 1019 |
+
device = torch.device("cuda", args.local_rank)
|
| 1020 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 1021 |
+
args.n_gpu = 1
|
| 1022 |
+
args.device = device
|
| 1023 |
+
|
| 1024 |
+
# Setup logging
|
| 1025 |
+
logging.basicConfig(
|
| 1026 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 1027 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 1028 |
+
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
| 1029 |
+
)
|
| 1030 |
+
logger.warning(
|
| 1031 |
+
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
| 1032 |
+
args.local_rank,
|
| 1033 |
+
device,
|
| 1034 |
+
args.n_gpu,
|
| 1035 |
+
bool(args.local_rank != -1),
|
| 1036 |
+
args.fp16,
|
| 1037 |
+
)
|
| 1038 |
+
|
| 1039 |
+
# Set seed
|
| 1040 |
+
set_seed(args)
|
| 1041 |
+
|
| 1042 |
+
# Prepare GLUE task
|
| 1043 |
+
args.task_name = args.task_name.lower()
|
| 1044 |
+
if args.task_name not in processors:
|
| 1045 |
+
raise ValueError("Task not found: %s" % (args.task_name))
|
| 1046 |
+
processor = processors[args.task_name]()
|
| 1047 |
+
args.output_mode = output_modes[args.task_name]
|
| 1048 |
+
label_list = processor.get_labels()
|
| 1049 |
+
num_labels = len(label_list)
|
| 1050 |
+
|
| 1051 |
+
# Load pretrained model and tokenizer
|
| 1052 |
+
if args.local_rank not in [-1, 0]:
|
| 1053 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
| 1054 |
+
|
| 1055 |
+
args.model_type = args.model_type.lower()
|
| 1056 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
| 1057 |
+
|
| 1058 |
+
if not args.do_visualize and not args.do_ensemble_pred:
|
| 1059 |
+
config = config_class.from_pretrained(
|
| 1060 |
+
args.config_name if args.config_name else args.model_name_or_path,
|
| 1061 |
+
num_labels=num_labels,
|
| 1062 |
+
finetuning_task=args.task_name,
|
| 1063 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
config.hidden_dropout_prob = args.hidden_dropout_prob
|
| 1067 |
+
config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
|
| 1068 |
+
if args.model_type in ["dnalong", "dnalongcat"]:
|
| 1069 |
+
assert args.max_seq_length % 512 == 0
|
| 1070 |
+
config.split = int(args.max_seq_length/512)
|
| 1071 |
+
config.rnn = args.rnn
|
| 1072 |
+
config.num_rnn_layer = args.num_rnn_layer
|
| 1073 |
+
config.rnn_dropout = args.rnn_dropout
|
| 1074 |
+
config.rnn_hidden = args.rnn_hidden
|
| 1075 |
+
|
| 1076 |
+
tokenizer = tokenizer_class.from_pretrained(
|
| 1077 |
+
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
| 1078 |
+
do_lower_case=args.do_lower_case,
|
| 1079 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1080 |
+
)
|
| 1081 |
+
model = model_class.from_pretrained(
|
| 1082 |
+
args.model_name_or_path,
|
| 1083 |
+
from_tf=bool(".ckpt" in args.model_name_or_path),
|
| 1084 |
+
config=config,
|
| 1085 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1086 |
+
)
|
| 1087 |
+
logger.info('finish loading model')
|
| 1088 |
+
|
| 1089 |
+
if args.local_rank == 0:
|
| 1090 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
| 1091 |
+
|
| 1092 |
+
model.to(args.device)
|
| 1093 |
+
|
| 1094 |
+
logger.info("Training/evaluation parameters %s", args)
|
| 1095 |
+
|
| 1096 |
+
# Training
|
| 1097 |
+
if args.do_train:
|
| 1098 |
+
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
|
| 1099 |
+
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
| 1100 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
| 1101 |
+
|
| 1102 |
+
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
| 1103 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and args.task_name != "dna690":
|
| 1104 |
+
# Create output directory if needed
|
| 1105 |
+
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
| 1106 |
+
os.makedirs(args.output_dir)
|
| 1107 |
+
|
| 1108 |
+
logger.info("Saving model checkpoint to %s", args.output_dir)
|
| 1109 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
| 1110 |
+
# They can then be reloaded using `from_pretrained()`
|
| 1111 |
+
model_to_save = (
|
| 1112 |
+
model.module if hasattr(model, "module") else model
|
| 1113 |
+
) # Take care of distributed/parallel training
|
| 1114 |
+
model_to_save.save_pretrained(args.output_dir)
|
| 1115 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 1116 |
+
|
| 1117 |
+
# Good practice: save your training arguments together with the trained model
|
| 1118 |
+
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
| 1119 |
+
|
| 1120 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
| 1121 |
+
model = model_class.from_pretrained(args.output_dir)
|
| 1122 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
| 1123 |
+
model.to(args.device)
|
| 1124 |
+
|
| 1125 |
+
# Evaluation
|
| 1126 |
+
results = {}
|
| 1127 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
| 1128 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
| 1129 |
+
checkpoints = [args.output_dir]
|
| 1130 |
+
if args.eval_all_checkpoints:
|
| 1131 |
+
checkpoints = list(
|
| 1132 |
+
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
| 1133 |
+
)
|
| 1134 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
| 1135 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
| 1136 |
+
for checkpoint in checkpoints:
|
| 1137 |
+
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
| 1138 |
+
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
| 1139 |
+
|
| 1140 |
+
model = model_class.from_pretrained(checkpoint)
|
| 1141 |
+
model.to(args.device)
|
| 1142 |
+
result = evaluate(args, model, tokenizer, prefix=prefix)
|
| 1143 |
+
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
| 1144 |
+
results.update(result)
|
| 1145 |
+
|
| 1146 |
+
# Prediction
|
| 1147 |
+
predictions = {}
|
| 1148 |
+
if args.do_predict and args.local_rank in [-1, 0]:
|
| 1149 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
| 1150 |
+
checkpoint = args.output_dir
|
| 1151 |
+
logger.info("Predict using the following checkpoint: %s", checkpoint)
|
| 1152 |
+
prefix = ''
|
| 1153 |
+
model = model_class.from_pretrained(checkpoint)
|
| 1154 |
+
model.to(args.device)
|
| 1155 |
+
prediction = predict(args, model, tokenizer, prefix=prefix)
|
| 1156 |
+
|
| 1157 |
+
# Visualize
|
| 1158 |
+
if args.do_visualize and args.local_rank in [-1, 0]:
|
| 1159 |
+
visualization_models = [3,4,5,6] if not args.visualize_models else [args.visualize_models]
|
| 1160 |
+
|
| 1161 |
+
scores = None
|
| 1162 |
+
all_probs = None
|
| 1163 |
+
|
| 1164 |
+
for kmer in visualization_models:
|
| 1165 |
+
output_dir = args.output_dir.replace("/690", "/690/" + str(kmer))
|
| 1166 |
+
#checkpoint_name = os.listdir(output_dir)[0]
|
| 1167 |
+
#output_dir = os.path.join(output_dir, checkpoint_name)
|
| 1168 |
+
|
| 1169 |
+
tokenizer = tokenizer_class.from_pretrained(
|
| 1170 |
+
"dna"+str(kmer),
|
| 1171 |
+
do_lower_case=args.do_lower_case,
|
| 1172 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1173 |
+
)
|
| 1174 |
+
checkpoint = output_dir
|
| 1175 |
+
logger.info("Calculate attention score using the following checkpoint: %s", checkpoint)
|
| 1176 |
+
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
| 1177 |
+
config = config_class.from_pretrained(
|
| 1178 |
+
output_dir,
|
| 1179 |
+
num_labels=num_labels,
|
| 1180 |
+
finetuning_task=args.task_name,
|
| 1181 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1182 |
+
)
|
| 1183 |
+
config.output_attentions = True
|
| 1184 |
+
model = model_class.from_pretrained(
|
| 1185 |
+
checkpoint,
|
| 1186 |
+
from_tf=bool(".ckpt" in args.model_name_or_path),
|
| 1187 |
+
config=config,
|
| 1188 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1189 |
+
)
|
| 1190 |
+
model.to(args.device)
|
| 1191 |
+
attention_scores, probs = visualize(args, model, tokenizer, prefix=prefix, kmer=kmer)
|
| 1192 |
+
if scores is not None:
|
| 1193 |
+
all_probs += probs
|
| 1194 |
+
scores += attention_scores
|
| 1195 |
+
else:
|
| 1196 |
+
all_probs = deepcopy(probs)
|
| 1197 |
+
scores = deepcopy(attention_scores)
|
| 1198 |
+
|
| 1199 |
+
all_probs = all_probs/float(len(visualization_models))
|
| 1200 |
+
np.save(os.path.join(args.predict_dir, "atten.npy"), scores)
|
| 1201 |
+
np.save(os.path.join(args.predict_dir, "pred_results.npy"), all_probs)
|
| 1202 |
+
|
| 1203 |
+
# ensemble prediction
|
| 1204 |
+
if args.do_ensemble_pred and args.local_rank in [-1, 0]:
|
| 1205 |
+
|
| 1206 |
+
for kmer in range(3,7):
|
| 1207 |
+
output_dir = os.path.join(args.output_dir, str(kmer))
|
| 1208 |
+
tokenizer = tokenizer_class.from_pretrained(
|
| 1209 |
+
"dna"+str(kmer),
|
| 1210 |
+
do_lower_case=args.do_lower_case,
|
| 1211 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1212 |
+
)
|
| 1213 |
+
checkpoint = output_dir
|
| 1214 |
+
logger.info("Calculate attention score using the following checkpoint: %s", checkpoint)
|
| 1215 |
+
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
| 1216 |
+
config = config_class.from_pretrained(
|
| 1217 |
+
output_dir,
|
| 1218 |
+
num_labels=num_labels,
|
| 1219 |
+
finetuning_task=args.task_name,
|
| 1220 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1221 |
+
)
|
| 1222 |
+
config.output_attentions = True
|
| 1223 |
+
model = model_class.from_pretrained(
|
| 1224 |
+
args.model_name_or_path,
|
| 1225 |
+
from_tf=bool(".ckpt" in args.model_name_or_path),
|
| 1226 |
+
config=config,
|
| 1227 |
+
cache_dir=args.cache_dir if args.cache_dir else None,
|
| 1228 |
+
)
|
| 1229 |
+
model.to(args.device)
|
| 1230 |
+
if kmer == 3:
|
| 1231 |
+
args.data_dir = os.path.join(args.data_dir, str(kmer))
|
| 1232 |
+
else:
|
| 1233 |
+
args.data_dir = args.data_dir.replace("/"+str(kmer-1), "/"+str(kmer))
|
| 1234 |
+
|
| 1235 |
+
if args.result_dir.split('/')[-1] == "test.npy":
|
| 1236 |
+
results, eval_task, _, out_label_ids, probs = evaluate(args, model, tokenizer, prefix=prefix)
|
| 1237 |
+
elif args.result_dir.split('/')[-1] == "train.npy":
|
| 1238 |
+
results, eval_task, _, out_label_ids, probs = evaluate(args, model, tokenizer, prefix=prefix, evaluate=False)
|
| 1239 |
+
else:
|
| 1240 |
+
raise ValueError("file name in result_dir should be either test.npy or train.npy")
|
| 1241 |
+
|
| 1242 |
+
if kmer == 3:
|
| 1243 |
+
all_probs = deepcopy(probs)
|
| 1244 |
+
cat_probs = deepcopy(probs)
|
| 1245 |
+
else:
|
| 1246 |
+
all_probs += probs
|
| 1247 |
+
cat_probs = np.concatenate((cat_probs, probs), axis=1)
|
| 1248 |
+
print(cat_probs[0])
|
| 1249 |
+
|
| 1250 |
+
|
| 1251 |
+
all_probs = all_probs / 4.0
|
| 1252 |
+
all_preds = np.argmax(all_probs, axis=1)
|
| 1253 |
+
|
| 1254 |
+
# save label and data for stuck ensemble
|
| 1255 |
+
labels = np.array(out_label_ids)
|
| 1256 |
+
labels = labels.reshape(labels.shape[0],1)
|
| 1257 |
+
data = np.concatenate((cat_probs, labels), axis=1)
|
| 1258 |
+
random.shuffle(data)
|
| 1259 |
+
root_path = args.result_dir.replace(args.result_dir.split('/')[-1],'')
|
| 1260 |
+
if not os.path.exists(root_path):
|
| 1261 |
+
os.makedirs(root_path)
|
| 1262 |
+
# data_path = os.path.join(root_path, "data")
|
| 1263 |
+
# pred_path = os.path.join(root_path, "pred")
|
| 1264 |
+
# if not os.path.exists(data_path):
|
| 1265 |
+
# os.makedirs(data_path)
|
| 1266 |
+
# if not os.path.exists(pred_path):
|
| 1267 |
+
# os.makedirs(pred_path)
|
| 1268 |
+
# np.save(os.path.join(data_path, args.result_dir.split('/')[-1]), data)
|
| 1269 |
+
# np.save(os.path.join(pred_path, "pred_results.npy", all_probs[:,1]))
|
| 1270 |
+
np.save(args.result_dir, data)
|
| 1271 |
+
ensemble_results = compute_metrics(eval_task, all_preds, out_label_ids, all_probs[:,1])
|
| 1272 |
+
logger.info("***** Ensemble results {} *****".format(prefix))
|
| 1273 |
+
for key in sorted(ensemble_results.keys()):
|
| 1274 |
+
logger.info(" %s = %s", key, str(ensemble_results[key]))
|
| 1275 |
+
|
| 1276 |
+
|
| 1277 |
+
|
| 1278 |
+
|
| 1279 |
+
|
| 1280 |
+
return results
|
| 1281 |
+
|
| 1282 |
+
|
| 1283 |
+
if __name__ == "__main__":
|
| 1284 |
+
main()
|
examples/run_pretrain.py
ADDED
|
@@ -0,0 +1,885 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
|
| 18 |
+
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
|
| 19 |
+
using a masked language modeling (MLM) loss.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import glob
|
| 25 |
+
import logging
|
| 26 |
+
import os
|
| 27 |
+
import pickle
|
| 28 |
+
import random
|
| 29 |
+
import re
|
| 30 |
+
import shutil
|
| 31 |
+
from typing import Dict, List, Tuple
|
| 32 |
+
from copy import deepcopy
|
| 33 |
+
from multiprocessing import Pool
|
| 34 |
+
|
| 35 |
+
import numpy as np
|
| 36 |
+
import torch
|
| 37 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 38 |
+
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
| 39 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 40 |
+
from tqdm import tqdm, trange
|
| 41 |
+
|
| 42 |
+
from transformers import (
|
| 43 |
+
WEIGHTS_NAME,
|
| 44 |
+
AdamW,
|
| 45 |
+
BertConfig,
|
| 46 |
+
BertForMaskedLM,
|
| 47 |
+
BertTokenizer,
|
| 48 |
+
DNATokenizer,
|
| 49 |
+
CamembertConfig,
|
| 50 |
+
CamembertForMaskedLM,
|
| 51 |
+
CamembertTokenizer,
|
| 52 |
+
DistilBertConfig,
|
| 53 |
+
DistilBertForMaskedLM,
|
| 54 |
+
DistilBertTokenizer,
|
| 55 |
+
GPT2Config,
|
| 56 |
+
GPT2LMHeadModel,
|
| 57 |
+
GPT2Tokenizer,
|
| 58 |
+
OpenAIGPTConfig,
|
| 59 |
+
OpenAIGPTLMHeadModel,
|
| 60 |
+
OpenAIGPTTokenizer,
|
| 61 |
+
PreTrainedModel,
|
| 62 |
+
PreTrainedTokenizer,
|
| 63 |
+
RobertaConfig,
|
| 64 |
+
RobertaForMaskedLM,
|
| 65 |
+
RobertaTokenizer,
|
| 66 |
+
get_linear_schedule_with_warmup,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 72 |
+
except ImportError:
|
| 73 |
+
from tensorboardX import SummaryWriter
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
logger = logging.getLogger(__name__)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# Registry mapping a model-type key (presumably args.model_type — it is used that
# way in the dataset cache-file name below; confirm against the argparse setup,
# which is outside this view) to its (config class, LM model class, tokenizer
# class) triple. Note "dna" reuses BERT's config and masked-LM model but pairs
# them with DNATokenizer.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
}
|
| 88 |
+
|
| 89 |
+
# Per-k-mer lists of relative offsets, keyed by the k-mer size as a string.
# NOTE(review): presumably the offsets of neighboring tokens that get masked
# together with a sampled position (contiguous k-mer masking, DNABERT-style),
# e.g. for k=6 a hit at i also masks i-2..i+3 — confirm against the
# mask_tokens implementation, which is outside this view.
MASK_LIST = {
    "3": [-1, 1],
    "4": [-1, 1, 2],
    "5": [-2, -1, 1, 2],
    "6": [-2, -1, 1, 2, 3]
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class TextDataset(Dataset):
    """Dataset of contiguous fixed-size token blocks taken from one text file.

    The whole file is tokenized once, sliced into non-overlapping blocks of
    ``block_size`` ids (each wrapped with the tokenizer's special tokens), and
    the resulting list is pickled next to the source file so subsequent runs
    can skip tokenization entirely.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        assert os.path.isfile(file_path)

        # Leave room for the special tokens the tokenizer wraps around a
        # single sequence.
        block_size = block_size - (tokenizer.max_len - tokenizer.max_len_single_sentence)

        folder, base_name = os.path.split(file_path)
        cached_features_file = os.path.join(
            folder, args.model_type + "_cached_lm_" + str(block_size) + "_" + base_name
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            # Cache hit: reuse the previously pickled examples.
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
            return

        logger.info("Creating features from dataset file at %s", folder)

        with open(file_path, encoding="utf-8") as f:
            raw_text = f.read()

        token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))

        # Chop into non-overlapping blocks. The trailing partial block is
        # dropped (no padding), losing at most block_size - 1 tokens — if the
        # dataset is tiny that matters; add model-specific padding to change it.
        self.examples = [
            tokenizer.build_inputs_with_special_tokens(token_ids[start : start + block_size])
            for start in range(0, len(token_ids) - block_size + 1, block_size)
        ]

        logger.info("Saving features into cached file %s", cached_features_file)
        with open(cached_features_file, "wb") as handle:
            pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        # Number of cached/constructed blocks.
        return len(self.examples)

    def __getitem__(self, item):
        # Materialize one block as a LongTensor of token ids.
        return torch.tensor(self.examples[item], dtype=torch.long)
|
| 136 |
+
|
| 137 |
+
def convert_line_to_example(tokenizer, lines, max_length, add_special_tokens=True):
    """Encode a batch of raw text lines into lists of input ids.

    Thin module-level wrapper around ``tokenizer.batch_encode_plus`` so the
    work can be farmed out to worker processes (it must be picklable for
    ``Pool.apply_async``).
    """
    encoded = tokenizer.batch_encode_plus(
        lines, add_special_tokens=add_special_tokens, max_length=max_length
    )
    return encoded["input_ids"]
|
| 140 |
+
|
| 141 |
+
class LineByLineTextDataset(Dataset):
    """Dataset that treats each non-empty line of the input file as one example.

    Tokenized examples are pickled to a cache file next to the input so
    repeated runs skip re-tokenization (unless ``args.overwrite_cache``).
    Tokenization can be fanned out over ``args.n_process`` worker processes.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path: str, block_size=512):
        assert os.path.isfile(file_path)
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        # NOTE(review): despite the comment above, this class DOES cache —
        # see the pickle load/dump below.
        directory, filename = os.path.split(file_path)
        # Cache path encodes model type and block size so different configs
        # don't collide on the same input file.
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", file_path)

            # Drop empty / whitespace-only lines; one example per remaining line.
            with open(file_path, encoding="utf-8") as f:
                lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

            if args.n_process == 1:
                # Single-process path: tokenize everything in one call.
                self.examples = tokenizer.batch_encode_plus(lines, add_special_tokens=True, max_length=block_size)["input_ids"]
            else:
                # Multi-process path: split `lines` into n_proc contiguous
                # slices and tokenize each slice in a worker.
                n_proc = args.n_process
                p = Pool(n_proc)
                # `indexes` holds slice boundaries: [0, len_slice, 2*len_slice,
                # ..., len(lines)]; the last slice absorbs the remainder.
                indexes = [0]
                len_slice = int(len(lines)/n_proc)
                for i in range(1, n_proc+1):
                    if i != n_proc:
                        indexes.append(len_slice*(i))
                    else:
                        indexes.append(len(lines))
                results = []
                for i in range(n_proc):
                    # convert_line_to_example is module-level so it pickles
                    # cleanly for apply_async.
                    results.append(p.apply_async(convert_line_to_example,[tokenizer, lines[indexes[i]:indexes[i+1]], block_size,]))
                    print(str(i) + " start")
                p.close()
                p.join()

                # Collect worker outputs in submission order so example order
                # matches the input file.
                self.examples = []
                for result in results:
                    ids = result.get()
                    self.examples.extend(ids)

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Return the number of examples."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return example ``i`` as a 1-D LongTensor of token ids."""
        return torch.tensor(self.examples[i], dtype=torch.long)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the train or eval dataset.

    Picks the eval file when ``evaluate`` is set, and the line-by-line
    dataset variant when ``args.line_by_line`` is set; otherwise the
    contiguous-block ``TextDataset`` is used.
    """
    file_path = args.eval_data_file if evaluate else args.train_data_file
    dataset_cls = LineByLineTextDataset if args.line_by_line else TextDataset
    return dataset_cls(tokenizer, args, file_path=file_path, block_size=args.block_size)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def set_seed(args):
    """Seed every RNG in play (python, numpy, torch, CUDA) for reproducibility.

    CUDA seeding only happens when ``args.n_gpu`` indicates GPUs are in use.
    """
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    """Return checkpoint paths under ``args.output_dir``, oldest first.

    The sort key is the file mtime when ``use_mtime`` is set; otherwise it is
    the integer step suffix parsed from the ``<prefix>-<step>`` directory name.
    Paths that do not match the expected name pattern are dropped.
    """
    candidates = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    keyed_paths = []
    for path in candidates:
        if use_mtime:
            keyed_paths.append((os.path.getmtime(path), path))
        else:
            match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if match and match.groups():
                keyed_paths.append((int(match.groups()[0]), path))

    return [path for _, path in sorted(keyed_paths)]
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    """Delete the oldest checkpoints so at most ``args.save_total_limit`` remain.

    No-op when the limit is unset or non-positive, or when the number of
    existing checkpoints is already within the limit.
    """
    limit = args.save_total_limit
    if not limit:
        return
    if limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= limit:
        return

    excess = max(0, len(checkpoints_sorted) - limit)
    for checkpoint in checkpoints_sorted[:excess]:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
    # Unlike vanilla BERT masking, each sampled position ("center") is expanded
    # to the neighboring offsets in MASK_LIST[tokenizer.kmer], so a contiguous
    # span of overlapping k-mer tokens is masked together.
    # NOTE(review): mutates `inputs` in place and returns it alongside the
    # label tensor (labels are -100 everywhere except masked positions).

    # Relative offsets to also mask around each sampled center; keyed by the
    # tokenizer's k-mer size (MASK_LIST is a module-level table).
    mask_list = MASK_LIST[tokenizer.kmer]

    if tokenizer.mask_token is None:
        raise ValueError(
            "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
        )

    labels = inputs.clone()
    # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
    probability_matrix = torch.full(labels.shape, args.mlm_probability)
    # Zero out the masking probability on special tokens ([CLS]/[SEP]/etc.)...
    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
    ]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    # ...and on padding positions.
    if tokenizer._pad_token is not None:
        padding_mask = labels.eq(tokenizer.pad_token_id)
        probability_matrix.masked_fill_(padding_mask, value=0.0)

    # Draw the initial set of mask "centers".
    masked_indices = torch.bernoulli(probability_matrix).bool()

    # change masked indices
    # Expand each center to its k-mer neighborhood, per sequence.
    masks = deepcopy(masked_indices)
    for i, masked_index in enumerate(masks):
        # `end` = last position with non-zero masking probability, i.e. the
        # last real (non-special, non-pad) token in sequence i.
        end = torch.where(probability_matrix[i]!=0)[0].tolist()[-1]
        mask_centers = set(torch.where(masked_index==1)[0].tolist())
        new_centers = deepcopy(mask_centers)
        for center in mask_centers:
            for mask_number in mask_list:
                current_index = center + mask_number
                # Keep expansion inside [1, end] so position 0 (typically a
                # special token) and the padded tail stay unmasked.
                if current_index <= end and current_index >= 1:
                    new_centers.add(current_index)
        new_centers = list(new_centers)
        masked_indices[i][new_centers] = True


    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, we replace masked input tokens with random word
    # (0.5 of the remaining 20% = 10% overall)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
    return inputs, labels
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    # Returns (global_step, average training loss). Supports gradient
    # accumulation, apex fp16, DataParallel and DistributedDataParallel, and
    # resuming from a `checkpoint-<step>` directory.
    if args.local_rank in [-1, 0]:
        # TensorBoard logging only on the main process.
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Pad variable-length examples to the batch max; falls back to
        # zero-padding when the tokenizer has no pad token.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate
    )

    # max_steps (if set) overrides num_train_epochs; t_total is the number of
    # optimizer updates, not micro-batches.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    # Bias and LayerNorm weights are excluded from weight decay, as in BERT.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, betas=(args.beta1,args.beta2))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            # (the step count is parsed out of the "checkpoint-<step>" name).
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            # Path has no numeric suffix: treat as a fresh start.
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    model_to_resize = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    # Make the embedding matrix match the (possibly extended) vocabulary.
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    # NOTE(review): ids_set is only used by the commented-out token-frequency
    # debug code below; it is otherwise dead.
    ids_set = {'0':0,'1':0,'2':0,'3':0,'4':0,'5':0,'6':0,'7':0,'8':0}
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            # print(inputs.shape)
            # print(inputs)
            # for i in range(len(inputs)):
            #     for j in range(len(inputs[i])):
            #         ids_set[str(int(inputs[i][j]))] += 1
            # print(ids_set)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            # Only step the optimizer every gradient_accumulation_steps
            # micro-batches.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    # Prune old checkpoints per --save_total_limit.
                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
    """Evaluate LM loss on the eval file and return {"perplexity": tensor}.

    Also appends the perplexity value to ``<output_dir>/<prefix>/eval_results.txt``.
    """
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)

    if args.local_rank in [-1, 0]:
        os.makedirs(eval_output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly

    def collate(examples: List[torch.Tensor]):
        # Same padding collate as in train(): zero-pad when no pad token.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate
    )

    # multi-gpu evaluate
    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        # With --mlm, evaluation uses the same random masking as training.
        inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    # Perplexity = exp(mean cross-entropy loss).
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    # Append (not overwrite) so successive evaluations accumulate one value
    # per line.
    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "a") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write(str(float(perplexity)) + "\n")
            # writer.write("%s = %s\n" % (key, str(result[key])))

    return result
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def main():
|
| 563 |
+
parser = argparse.ArgumentParser()
|
| 564 |
+
|
| 565 |
+
# Required parameters
|
| 566 |
+
parser.add_argument(
|
| 567 |
+
"--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
|
| 568 |
+
)
|
| 569 |
+
parser.add_argument(
|
| 570 |
+
"--output_dir",
|
| 571 |
+
type=str,
|
| 572 |
+
required=True,
|
| 573 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
| 574 |
+
)
|
| 575 |
+
parser.add_argument(
|
| 576 |
+
"--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
# Other parameters
|
| 580 |
+
parser.add_argument(
|
| 581 |
+
"--eval_data_file",
|
| 582 |
+
default=None,
|
| 583 |
+
type=str,
|
| 584 |
+
help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
|
| 585 |
+
)
|
| 586 |
+
parser.add_argument(
|
| 587 |
+
"--line_by_line",
|
| 588 |
+
action="store_true",
|
| 589 |
+
help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
|
| 590 |
+
)
|
| 591 |
+
parser.add_argument(
|
| 592 |
+
"--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
|
| 593 |
+
)
|
| 594 |
+
parser.add_argument(
|
| 595 |
+
"--model_name_or_path",
|
| 596 |
+
default=None,
|
| 597 |
+
type=str,
|
| 598 |
+
help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
parser.add_argument(
|
| 602 |
+
"--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling."
|
| 603 |
+
)
|
| 604 |
+
parser.add_argument(
|
| 605 |
+
"--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss"
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
parser.add_argument(
|
| 609 |
+
"--config_name",
|
| 610 |
+
default=None,
|
| 611 |
+
type=str,
|
| 612 |
+
help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
|
| 613 |
+
)
|
| 614 |
+
parser.add_argument(
|
| 615 |
+
"--tokenizer_name",
|
| 616 |
+
default=None,
|
| 617 |
+
type=str,
|
| 618 |
+
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
|
| 619 |
+
)
|
| 620 |
+
parser.add_argument(
|
| 621 |
+
"--cache_dir",
|
| 622 |
+
default=None,
|
| 623 |
+
type=str,
|
| 624 |
+
help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
|
| 625 |
+
)
|
| 626 |
+
parser.add_argument(
|
| 627 |
+
"--block_size",
|
| 628 |
+
default=-1,
|
| 629 |
+
type=int,
|
| 630 |
+
help="Optional input sequence length after tokenization."
|
| 631 |
+
"The training dataset will be truncated in block of this size for training."
|
| 632 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens).",
|
| 633 |
+
)
|
| 634 |
+
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
| 635 |
+
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
| 636 |
+
parser.add_argument(
|
| 637 |
+
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
|
| 641 |
+
parser.add_argument(
|
| 642 |
+
"--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation."
|
| 643 |
+
)
|
| 644 |
+
parser.add_argument(
|
| 645 |
+
"--gradient_accumulation_steps",
|
| 646 |
+
type=int,
|
| 647 |
+
default=1,
|
| 648 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
| 649 |
+
)
|
| 650 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
| 651 |
+
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
| 652 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
| 653 |
+
parser.add_argument("--beta1", default=0.9, type=float, help="Beta1 for Adam optimizer.")
|
| 654 |
+
parser.add_argument("--beta2", default=0.999, type=float, help="Beta2 for Adam optimizer.")
|
| 655 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
| 656 |
+
parser.add_argument(
|
| 657 |
+
"--num_train_epochs", default=1.0, type=float, help="Total number of training epochs to perform."
|
| 658 |
+
)
|
| 659 |
+
parser.add_argument(
|
| 660 |
+
"--max_steps",
|
| 661 |
+
default=-1,
|
| 662 |
+
type=int,
|
| 663 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
| 664 |
+
)
|
| 665 |
+
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
| 666 |
+
|
| 667 |
+
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
| 668 |
+
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
| 669 |
+
parser.add_argument(
|
| 670 |
+
"--save_total_limit",
|
| 671 |
+
type=int,
|
| 672 |
+
default=None,
|
| 673 |
+
help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default",
|
| 674 |
+
)
|
| 675 |
+
parser.add_argument(
|
| 676 |
+
"--eval_all_checkpoints",
|
| 677 |
+
action="store_true",
|
| 678 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number",
|
| 679 |
+
)
|
| 680 |
+
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
| 681 |
+
parser.add_argument(
|
| 682 |
+
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
| 683 |
+
)
|
| 684 |
+
parser.add_argument(
|
| 685 |
+
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
| 686 |
+
)
|
| 687 |
+
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
| 688 |
+
parser.add_argument("--n_process", type=int, default=1, help="")
|
| 689 |
+
|
| 690 |
+
parser.add_argument(
|
| 691 |
+
"--fp16",
|
| 692 |
+
action="store_true",
|
| 693 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
| 694 |
+
)
|
| 695 |
+
parser.add_argument(
|
| 696 |
+
"--fp16_opt_level",
|
| 697 |
+
type=str,
|
| 698 |
+
default="O1",
|
| 699 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
| 700 |
+
"See details at https://nvidia.github.io/apex/amp.html",
|
| 701 |
+
)
|
| 702 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
| 703 |
+
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
| 704 |
+
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
| 705 |
+
args = parser.parse_args()
|
| 706 |
+
|
| 707 |
+
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
| 708 |
+
raise ValueError(
|
| 709 |
+
"BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
|
| 710 |
+
"flag (masked language modeling)."
|
| 711 |
+
)
|
| 712 |
+
if args.eval_data_file is None and args.do_eval:
|
| 713 |
+
raise ValueError(
|
| 714 |
+
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
| 715 |
+
"or remove the --do_eval argument."
|
| 716 |
+
)
|
| 717 |
+
if args.should_continue:
|
| 718 |
+
sorted_checkpoints = _sorted_checkpoints(args)
|
| 719 |
+
if len(sorted_checkpoints) == 0:
|
| 720 |
+
raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
|
| 721 |
+
else:
|
| 722 |
+
args.model_name_or_path = sorted_checkpoints[-1]
|
| 723 |
+
|
| 724 |
+
if (
|
| 725 |
+
os.path.exists(args.output_dir)
|
| 726 |
+
and os.listdir(args.output_dir)
|
| 727 |
+
and args.do_train
|
| 728 |
+
and not args.overwrite_output_dir
|
| 729 |
+
):
|
| 730 |
+
raise ValueError(
|
| 731 |
+
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
|
| 732 |
+
args.output_dir
|
| 733 |
+
)
|
| 734 |
+
)
|
| 735 |
+
|
| 736 |
+
# Setup distant debugging if needed
|
| 737 |
+
if args.server_ip and args.server_port:
|
| 738 |
+
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
| 739 |
+
import ptvsd
|
| 740 |
+
|
| 741 |
+
print("Waiting for debugger attach")
|
| 742 |
+
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
| 743 |
+
ptvsd.wait_for_attach()
|
| 744 |
+
|
| 745 |
+
# Setup CUDA, GPU & distributed training
|
| 746 |
+
if args.local_rank == -1 or args.no_cuda:
|
| 747 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
| 748 |
+
args.n_gpu = torch.cuda.device_count()
|
| 749 |
+
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
| 750 |
+
torch.cuda.set_device(args.local_rank)
|
| 751 |
+
device = torch.device("cuda", args.local_rank)
|
| 752 |
+
torch.distributed.init_process_group(backend="nccl")
|
| 753 |
+
args.n_gpu = 1
|
| 754 |
+
args.device = device
|
| 755 |
+
|
| 756 |
+
# Setup logging
|
| 757 |
+
logging.basicConfig(
|
| 758 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 759 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 760 |
+
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
|
| 761 |
+
)
|
| 762 |
+
logger.warning(
|
| 763 |
+
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
| 764 |
+
args.local_rank,
|
| 765 |
+
device,
|
| 766 |
+
args.n_gpu,
|
| 767 |
+
bool(args.local_rank != -1),
|
| 768 |
+
args.fp16,
|
| 769 |
+
)
|
| 770 |
+
|
| 771 |
+
# Set seed
|
| 772 |
+
set_seed(args)
|
| 773 |
+
|
| 774 |
+
# Load pretrained model and tokenizer
|
| 775 |
+
if args.local_rank not in [-1, 0]:
|
| 776 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
| 777 |
+
|
| 778 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
| 779 |
+
|
| 780 |
+
if args.config_name:
|
| 781 |
+
config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
| 782 |
+
elif args.model_name_or_path:
|
| 783 |
+
config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
| 784 |
+
else:
|
| 785 |
+
config = config_class()
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
if args.tokenizer_name:
|
| 789 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
| 790 |
+
elif args.model_name_or_path:
|
| 791 |
+
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
| 792 |
+
else:
|
| 793 |
+
raise ValueError(
|
| 794 |
+
"You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
|
| 795 |
+
"and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
# text = "C G A T A T A G"
|
| 799 |
+
# print(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
|
| 800 |
+
|
| 801 |
+
if args.block_size <= 0:
|
| 802 |
+
args.block_size = tokenizer.max_len
|
| 803 |
+
# Our input block size will be the max possible for the model
|
| 804 |
+
else:
|
| 805 |
+
args.block_size = min(args.block_size, tokenizer.max_len)
|
| 806 |
+
|
| 807 |
+
if args.model_name_or_path:
|
| 808 |
+
model = model_class.from_pretrained(
|
| 809 |
+
args.model_name_or_path,
|
| 810 |
+
from_tf=bool(".ckpt" in args.model_name_or_path),
|
| 811 |
+
config=config,
|
| 812 |
+
cache_dir=args.cache_dir,
|
| 813 |
+
)
|
| 814 |
+
else:
|
| 815 |
+
logger.info("Training new model from scratch")
|
| 816 |
+
model = model_class(config=config)
|
| 817 |
+
|
| 818 |
+
model.to(args.device)
|
| 819 |
+
|
| 820 |
+
if args.local_rank == 0:
|
| 821 |
+
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
|
| 822 |
+
|
| 823 |
+
logger.info("Training/evaluation parameters %s", args)
|
| 824 |
+
|
| 825 |
+
# Training
|
| 826 |
+
if args.do_train:
|
| 827 |
+
if args.local_rank not in [-1, 0]:
|
| 828 |
+
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
|
| 829 |
+
|
| 830 |
+
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
|
| 831 |
+
|
| 832 |
+
if args.local_rank == 0:
|
| 833 |
+
torch.distributed.barrier()
|
| 834 |
+
|
| 835 |
+
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
| 836 |
+
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
| 837 |
+
|
| 838 |
+
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
| 839 |
+
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
| 840 |
+
# Create output directory if needed
|
| 841 |
+
if args.local_rank in [-1, 0]:
|
| 842 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 843 |
+
|
| 844 |
+
logger.info("Saving model checkpoint to %s", args.output_dir)
|
| 845 |
+
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
| 846 |
+
# They can then be reloaded using `from_pretrained()`
|
| 847 |
+
model_to_save = (
|
| 848 |
+
model.module if hasattr(model, "module") else model
|
| 849 |
+
) # Take care of distributed/parallel training
|
| 850 |
+
model_to_save.save_pretrained(args.output_dir)
|
| 851 |
+
tokenizer.save_pretrained(args.output_dir)
|
| 852 |
+
|
| 853 |
+
# Good practice: save your training arguments together with the trained model
|
| 854 |
+
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
| 855 |
+
|
| 856 |
+
# Load a trained model and vocabulary that you have fine-tuned
|
| 857 |
+
model = model_class.from_pretrained(args.output_dir)
|
| 858 |
+
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
| 859 |
+
model.to(args.device)
|
| 860 |
+
|
| 861 |
+
# Evaluation
|
| 862 |
+
results = {}
|
| 863 |
+
if args.do_eval and args.local_rank in [-1, 0]:
|
| 864 |
+
checkpoints = [args.output_dir]
|
| 865 |
+
if args.eval_all_checkpoints:
|
| 866 |
+
checkpoints = list(
|
| 867 |
+
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
|
| 868 |
+
)
|
| 869 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
|
| 870 |
+
logger.info("Evaluate the following checkpoints: %s", checkpoints)
|
| 871 |
+
for checkpoint in checkpoints:
|
| 872 |
+
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
| 873 |
+
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
| 874 |
+
|
| 875 |
+
model = model_class.from_pretrained(checkpoint)
|
| 876 |
+
model.to(args.device)
|
| 877 |
+
result = evaluate(args, model, tokenizer, prefix=prefix)
|
| 878 |
+
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
| 879 |
+
results.update(result)
|
| 880 |
+
|
| 881 |
+
return results
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
if __name__ == "__main__":
|
| 885 |
+
main()
|
examples/run_pretrain.sh.save
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Launch with 4 processes (one for each GPU)
export KMER=6
export TRAIN_FILE=/home/n5huang/dna_token/output_tokens/all_tokenized_train.txt
export TEST_FILE=/home/n5huang/dna_token/output_tokens/all_tokenized_val.txt
export SOURCE=PATH_TO_DNABERT_REPO
export OUTPUT_PATH=output$KMER

# NOTE: a comment must never follow a trailing backslash inside a multi-line
# command -- "\ # text" escapes the space instead of continuing the line, so
# the command terminates early and the remaining --flags run as separate
# (failing) commands. The per-option notes therefore live here:
#   --gradient_accumulation_steps 7   adjusted for 4 GPUs (10 * 7 * 4 = 280)
#   --max_steps 10000                 recommended starting point for a custom dataset
python run_pretrain.py \
    --output_dir $OUTPUT_PATH \
    --model_type=dna \
    --tokenizer_name=dna$KMER \
    --config_name=$SOURCE/src/transformers/dnabert-config/bert-config-$KMER/config.json \
    --do_train \
    --train_data_file=$TRAIN_FILE \
    --do_eval \
    --eval_data_file=$TEST_FILE \
    --mlm \
    --gradient_accumulation_steps 7 \
    --per_gpu_train_batch_size 10 \
    --per_gpu_eval_batch_size 6 \
    --save_steps 500 \
    --save_total_limit 20 \
    --max_steps 10000 \
    --evaluate_during_training \
    --logging_steps 500 \
    --line_by_line \
    --learning_rate 4e-4 \
    --block_size 512 \
    --adam_epsilon 1e-6 \
    --weight_decay 0.01 \
    --beta1 0.9 \
    --beta2 0.98 \
    --mlm_probability 0.025 \
    --warmup_steps 10000 \
    --overwrite_output_dir \
    --n_process 24
|
examples/sample_data/ft/6/dev.tsv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/sample_data/ft/6/train.tsv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4a824c48fe4b7cd1cde690882f9cd50dd628165e168453a714065d21a9c9bc7c
|
| 3 |
+
size 21847066
|
examples/sample_data/pre/6_3k.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
examples/save_static_embeddings.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Extract the static (input) token-embedding matrix from a pretrained
DNABERT checkpoint and save it to disk as a ``.npy`` file."""

import torch
import numpy as np
import os
from transformers import BertModel, BertConfig, DNATokenizer, BertForMaskedLM

# --- CONFIGURATION ---
OUTPUT_FOLDER = "6mer_pretrain_emb_adaptive"
OUTPUT_FILENAME = "static_adaptive_embed.npy"
CHECKPOINT_PATH = "/data/n5huang/dna_token/pretrain_output_adaptive/checkpoint-10000/"

if not CHECKPOINT_PATH:
    # Guard against the constant above being blanked out during editing.
    # (The previous message claimed a MODEL_DIR environment variable was
    # unset, but no environment variable is read anywhere in this script.)
    raise EnvironmentError("CHECKPOINT_PATH is empty; set it to a pretrained checkpoint directory.")

# Model-class mapping, mirroring the one used by run_pretrain.py.
MODEL_CLASSES = {
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
}


def loadmodel(model_dir):
    """Load the base BertModel and tokenizer from ``model_dir``.

    The BASE model (BertModel, not the masked-LM head) is loaded explicitly
    so the raw input-embedding layer can be read directly.

    Arguments:
        model_dir -- str, path to a saved checkpoint directory.

    Returns:
        (model, tokenizer) tuple; the model is moved to GPU when available
        and put in eval mode.
    """
    config_class, _, tokenizer_class = MODEL_CLASSES['dna']

    config = config_class.from_pretrained(model_dir)

    # Explicitly load the BASE BERT MODEL (BertModel) to access the embedding layer.
    model = BertModel.from_pretrained(model_dir, config=config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    tokenizer = tokenizer_class.from_pretrained(model_dir)

    return model, tokenizer


# --- MAIN EXECUTION ---
if __name__ == "__main__":
    print("Starting model and tokenizer load...")
    model, tokenizer = loadmodel(CHECKPOINT_PATH)
    print(f"Model and Tokenizer loaded successfully. Vocab size: {len(tokenizer)}")

    # 1. Extract the static embedding layer (vocab_size x hidden_size matrix,
    #    one row per token id).
    embedding_layer = model.get_input_embeddings()
    print(embedding_layer.weight.shape)

    # 2. Extract the weights: detach from the device and convert to NumPy.
    static_embeddings_tensor = embedding_layer.weight.data.cpu()
    static_embeddings_array = static_embeddings_tensor.numpy()

    print(f"\nExtracted embedding tensor size: {static_embeddings_tensor.size()}")
    print(f"Extracted NumPy array shape: {static_embeddings_array.shape}")

    # 3. Save the embeddings.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)
    output_path = os.path.join(OUTPUT_FOLDER, OUTPUT_FILENAME)
    np.save(output_path, static_embeddings_array)

    print(f"\n✅ Successfully saved static embeddings to: {output_path}")
|
examples/scripts/run_mut.sh
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Run every fine-tuned DNABERT model under MODEL_PATH in prediction mode over
# the original (ori) and mutated (mut) sequence sets, skipping models whose
# results directory already exists.
export MODEL_PATH=/gluster/zhihan/backup/dna/690/6

# run_predict DATA_DIR RESULTS_DIR
#   The two original loops were byte-identical except for the data/results
#   directories, so the shared body is factored out here. Path expansions are
#   quoted to survive whitespace in model names.
run_predict() {
    local data_dir="$1"
    local results_dir="$2"
    for model in $(ls "$MODEL_PATH")
    do
        export MODEL="$model"
        # Use the first checkpoint directory found for this model.
        export CHECKPOINT=$(ls "$MODEL_PATH/$MODEL" | head -1)
        if [ ! -d "$results_dir/$MODEL" ]
        then
            python run_finetune.py \
                --model_type dna \
                --tokenizer_name=dna6 \
                --model_name_or_path "$MODEL_PATH/$MODEL/$CHECKPOINT" \
                --task_name dnaprom \
                --do_predict \
                --data_dir "$data_dir" \
                --max_seq_length 110 \
                --per_gpu_pred_batch_size=256 \
                --output_dir "$MODEL_PATH/$MODEL/$CHECKPOINT" \
                --predict_dir "$results_dir/$MODEL" \
                --fp16 \
                --n_process 96
        fi
    done
}

run_predict /gluster/zhihan/DNABERT/examples/data/ori /gluster/zhihan/DNABERT/examples/data/ori_results
run_predict /gluster/zhihan/DNABERT/examples/data/mut /gluster/zhihan/DNABERT/examples/data/mut_results
|
examples/scripts/uce.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Visualize attention for the first 345 fine-tuned DNABERT models under
# MODEL_PATH on the UCE dataset. Variable expansions are quoted so paths
# containing whitespace do not word-split.
export MODEL_PATH=/home/zhihan/6
# for cp in $(ls $MODEL_PATH)
# do
#     cd $MODEL_PATH/$cp
#     mv checkpoin* checkpoint-0
# done

for model in $(ls "$MODEL_PATH" | head -345)
do
    export MODEL="$model"
    # Each model directory is assumed to contain a single checkpoint dir.
    export CHECKPOINT=$(ls "$MODEL_PATH/$MODEL")
    CUDA_VISIBLE_DEVICES=0 python run_finetune.py \
        --model_type dna \
        --tokenizer_name=dna6 \
        --model_name_or_path "$MODEL_PATH/$MODEL/$CHECKPOINT" \
        --task_name dnaprom \
        --do_visualize \
        --visualize_data_dir /home/zhihan/data/uce/processed/ \
        --visualize_models 6 \
        --data_dir /home/zhihan/data/uce/processed/ \
        --max_seq_length 110 \
        --per_gpu_pred_batch_size=16 \
        --output_dir "$MODEL_PATH/$MODEL/$CHECKPOINT" \
        --predict_dir "/home/zhihan/data/uce/results/$MODEL" \
        --n_process 24
done
|
examples/visualize.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import seaborn as sns
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from transformers import BertTokenizer, BertModel, DNATokenizer
|
| 9 |
+
from process_pretrain_data import get_kmer_sentence
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def format_attention(attention):
    """Stack per-layer attention tensors into one tensor.

    Each element of ``attention`` must have shape
    (1, num_heads, seq_len, seq_len); the leading batch axis is dropped and
    the layers are stacked into (num_layers, num_heads, seq_len, seq_len).

    Raises:
        ValueError -- if any layer tensor is not 4-dimensional.
    """
    layers = []
    for layer in attention:
        if len(layer.shape) != 4:
            raise ValueError("The attention tensor does not have the correct number of dimensions. Make sure you set "
                             "output_attentions=True when initializing your model.")
        # Drop the singleton batch dimension.
        layers.append(layer.squeeze(0))
    return torch.stack(layers)
|
| 22 |
+
|
| 23 |
+
def get_attention_dna(model, tokenizer, sentence_a, start, end):
    """Return attention-to-[CLS] scores for each inner token of a sentence.

    The sentence is encoded with special tokens; for every token except the
    first and last special ones, the attention that position 0 pays to it is
    summed over layers ``start``..``end`` (inclusive) and over all heads.

    NOTE(review): assumes the model was built with output_attentions=True so
    its last output is the attention tuple -- confirm against the caller.
    """
    encoded = tokenizer.encode_plus(sentence_a, sentence_b=None, return_tensors='pt', add_special_tokens=True)
    token_ids = encoded['input_ids']
    attention = model(token_ids)[-1]
    tokens = tokenizer.convert_ids_to_tokens(token_ids[0].tolist())  # batch index 0
    stacked = format_attention(attention)
    # Attention from position 0 ([CLS]) to every non-special token, summed
    # across the selected layer range and every head.
    return [float(stacked[start:end + 1, :, 0, pos].sum()) for pos in range(1, len(tokens) - 1)]
|
| 34 |
+
|
| 35 |
+
def get_real_score(attention_scores, kmer, metric):
    """Project k-mer-level attention scores back onto nucleotide positions.

    Each k-mer score is spread over the k nucleotides it covers; with
    ``metric == "mean"`` every nucleotide receives the average of the scores
    of all k-mers overlapping it.

    Arguments:
        attention_scores -- sequence of per-k-mer attention scores
        kmer -- int, k-mer length used during tokenization
        metric -- str, aggregation method; only "mean" is implemented

    Returns:
        numpy array of length len(attention_scores) + kmer - 1, one score
        per nucleotide.

    Raises:
        ValueError -- for any metric other than "mean". (Previously an
        unknown metric silently returned an all-zero array.)
    """
    if metric != "mean":
        raise ValueError("Unsupported metric: {!r}; only 'mean' is implemented".format(metric))

    length = len(attention_scores) + kmer - 1
    counts = np.zeros([length])
    real_scores = np.zeros([length])

    # Accumulate each k-mer score onto the k positions it covers, counting
    # how many k-mers touched each position.
    for i, score in enumerate(attention_scores):
        for j in range(kmer):
            counts[i + j] += 1.0
            real_scores[i + j] += score

    return real_scores / counts
|
| 50 |
+
|
| 51 |
+
# Default sequence visualized when --sequence is not supplied.
SEQUENCE = "TGCCTGGCTTTTTGTAATTTTTGAAGAGACGGGGTTTTGCCATGATG"

def Visualize(args):
    """Plot nucleotide-level attention as a heatmap.

    With ``args.kmer == 0`` the 3/4/5/6-mer models under ``args.model_path``
    are all evaluated and their L2-normalized nucleotide scores summed;
    otherwise a single model at the requested k-mer length is used.
    """
    raw_sentence = args.sequence if args.sequence else SEQUENCE

    if args.kmer == 0:
        # Ensemble over all four k-mer models.
        scores = None
        for kmer in [3, 4, 5, 6]:
            model = BertModel.from_pretrained(os.path.join(args.model_path, str(kmer)), output_attentions=True)
            tokenizer = DNATokenizer.from_pretrained('dna' + str(kmer), do_lower_case=False)
            sentence_a = get_kmer_sentence(raw_sentence, kmer)
            tokens = sentence_a.split()

            attention = get_attention_dna(model, tokenizer, sentence_a, start=args.start_layer, end=args.end_layer)
            attention_scores = np.array(attention).reshape(np.array(attention).shape[0], 1)

            real_scores = get_real_score(attention_scores, kmer, args.metric)
            # Normalize each model's contribution before summing.
            real_scores = real_scores / np.linalg.norm(real_scores)
            row = real_scores.reshape(1, real_scores.shape[0])
            scores = row if scores is None else scores + row
    else:
        # Single model at the requested k-mer length.
        model = BertModel.from_pretrained(args.model_path, output_attentions=True)
        tokenizer = DNATokenizer.from_pretrained('dna' + str(args.kmer), do_lower_case=False)
        sentence_a = get_kmer_sentence(raw_sentence, args.kmer)
        tokens = sentence_a.split()

        attention = get_attention_dna(model, tokenizer, sentence_a, start=args.start_layer, end=args.end_layer)
        attention_scores = np.array(attention).reshape(np.array(attention).shape[0], 1)

        real_scores = get_real_score(attention_scores, args.kmer, args.metric)
        scores = real_scores.reshape(1, real_scores.shape[0])

    print(np.sum(scores) / scores.shape[1])
    print(scores)

    # plot
    sns.set()
    sns.heatmap(scores, cmap='YlGnBu', vmin=0)
    plt.show()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def main():
    """Parse command-line options and run the attention visualization."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--kmer", default=0, type=int, help="K-mer")
    parser.add_argument(
        "--model_path",
        default="/home/zhihan/dna/dna-transformers/examples/ft/690/p53-small/TAp73beta/3/",
        type=str,
        help="The path of the finetuned model",
    )
    parser.add_argument("--start_layer", default=11, type=int, help="Which layer to start")
    parser.add_argument("--end_layer", default=11, type=int, help="which layer to end")
    parser.add_argument(
        "--metric",
        default="mean",
        type=str,
        help="the metric used for integrate predicted kmer result to real result",
    )
    parser.add_argument("--sequence", default=None, type=str, help="the sequence for visualize")

    Visualize(parser.parse_args())


if __name__ == "__main__":
    main()
|
motif/find_motifs.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#### ::: DNABERT-viz find motifs ::: ####
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import argparse
|
| 7 |
+
import motif_utils as utils
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main():
    """DNABERT-viz motif search entry point.

    Loads saved attention scores and predictions, reconstructs the raw
    sequences from the k-mer dev set, splits them by label, and runs motif
    analysis over the positive examples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the sequence+label .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="Path where the attention scores were saved. Should contain both pred_results.npy and atten.npy",
    )
    parser.add_argument("--window_size", default=24, type=int,
                        help="Specified window size to be final motif length")
    parser.add_argument("--min_len", default=5, type=int,
                        help="Specified minimum length threshold for contiguous region")
    parser.add_argument("--pval_cutoff", default=0.005, type=float,
                        help="Cutoff FDR/p-value to declare statistical significance")
    parser.add_argument("--min_n_motif", default=3, type=int,
                        help="Minimum instance inside motif to be filtered")
    parser.add_argument("--align_all_ties", action='store_true',
                        help="Whether to keep all best alignments when ties encountered")
    parser.add_argument("--save_file_dir", default='.', type=str, help="Path to save outputs")
    parser.add_argument("--verbose", action='store_true', help="Verbosity controller")
    parser.add_argument("--return_idx", action='store_true',
                        help="Whether the indices of the motifs are only returned")

    # TODO: add the conditions
    args = parser.parse_args()

    atten_scores = np.load(os.path.join(args.predict_dir, "atten.npy"))
    pred = np.load(os.path.join(args.predict_dir, "pred_results.npy"))
    dev = pd.read_csv(os.path.join(args.data_dir, "dev.tsv"), sep='\t', header=0)
    dev.columns = ['sequence', 'label']
    dev['seq'] = dev['sequence'].apply(utils.kmer2seq)

    # Split by label; attention rows line up with the original row order.
    dev_pos = dev[dev['label'] == 1]
    dev_neg = dev[dev['label'] == 0]
    pos_atten_scores = atten_scores[dev_pos.index.values]
    neg_atten_scores = atten_scores[dev_neg.index.values]
    assert len(dev_pos) == len(pos_atten_scores)

    # run motif analysis
    merged_motif_seqs = utils.motif_analysis(
        dev_pos['seq'],
        dev_neg['seq'],
        pos_atten_scores,
        window_size=args.window_size,
        min_len=args.min_len,
        pval_cutoff=args.pval_cutoff,
        min_n_motif=args.min_n_motif,
        align_all_ties=args.align_all_ties,
        save_file_dir=args.save_file_dir,
        verbose=args.verbose,
        return_idx=args.return_idx,
    )


if __name__ == "__main__":
    main()
|
| 111 |
+
|
| 112 |
+
|
motif/motif_utils.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#### ::: utils for DNABERT-viz motif search ::: ####
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
def kmer2seq(kmers):
    """
    Convert kmers to original sequence

    Arguments:
    kmers -- str, kmers separated by space.

    Returns:
    seq -- str, original sequence.

    """
    pieces = kmers.split(" ")
    # First character of every k-mer except the last, then the entire final
    # k-mer, reconstructs the original sequence.
    seq = "".join(p[0] for p in pieces[:-1]) + pieces[-1]
    assert len(seq) == len(pieces) + len(pieces[0]) - 1
    return seq
|
| 24 |
+
|
| 25 |
+
def seq2kmer(seq, k):
    """
    Convert original sequence to kmers

    Arguments:
    seq -- str, original sequence.
    k -- int, kmer of length k specified.

    Returns:
    kmers -- str, kmers separated by space

    """
    # Slide a window of width k one base at a time and join with spaces.
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))
|
| 40 |
+
|
| 41 |
+
def contiguous_regions(condition, len_thres=5):
    """
    Modified from and credit to: https://stackoverflow.com/a/4495197/3751373
    Finds contiguous True regions of the boolean array "condition". Returns
    a 2D array where the first column is the start index of the region and the
    second column is the end index.

    Arguments:
        condition -- custom conditions to filter/select high attention
            (boolean numpy array)

    Keyword arguments:
        len_thres -- int, specified minimum length threshold for contiguous region
            (default 5)

    Returns:
        idx -- index pairs of contiguous regions in sequence
    """
    # Region boundaries sit wherever consecutive elements of `condition` differ.
    change_points, = np.diff(condition).nonzero()
    # Shift right so each index marks the element *after* the change.
    change_points += 1

    # Open a region at index 0 if the array starts True.
    if condition[0]:
        change_points = np.r_[0, change_points]
    # Close the last region at the array end if it ends True.
    if condition[-1]:
        change_points = np.r_[change_points, condition.size]

    # Pair up alternating start/end indices into rows of (start, end).
    regions = change_points.reshape(-1, 2)

    # Drop regions shorter than the length threshold.
    long_enough = (regions[:, 1] - regions[:, 0]) >= len_thres
    return regions[np.argwhere(long_enough).flatten()]
|
| 83 |
+
|
| 84 |
+
def find_high_attention(score, min_len=5, **kwargs):
    """
    With an array of attention scores as input, finds contiguous high attention
    sub-regions indices having length greater than min_len.

    Arguments:
        score -- numpy array of attention scores for a sequence

    Keyword arguments:
        min_len -- int, specified minimum length threshold for contiguous region
            (default 5)
        **kwargs -- other input arguments:
            cond -- custom conditions to filter/select high attention
                (list of boolean arrays)

    Returns:
        motif_regions -- indices of high attention regions in sequence
    """
    # Default: a position is "high attention" when its score is both above
    # the mean and more than 10x the minimum score.
    default_conds = [score > np.mean(score), score > 10 * np.min(score)]
    cond = list(map(all, zip(*default_conds)))

    # Custom conditions, if provided, override the defaults entirely.
    if 'cond' in kwargs:
        cond = kwargs['cond']
        if any(isinstance(c, list) for c in cond):  # multiple conditions -> AND them
            cond = list(map(all, zip(*cond)))

    cond = np.asarray(cond)

    # Extract contiguous runs of True that meet the length threshold.
    return contiguous_regions(cond, min_len)
|
| 122 |
+
|
| 123 |
+
def count_motif_instances(seqs, motifs, allow_multi_match=False):
    """
    Use Aho-Corasick algorithm for efficient multi-pattern matching
    between input sequences and motif patterns to obtain counts of instances.

    Arguments:
        seqs -- list, numpy array or pandas series of DNA sequences
        motifs -- list, numpy array or pandas series, a collection of motif patterns
            to be matched to seqs

    Keyword arguments:
        allow_multi_match -- bool, whether to allow for counting multiple matches
            per sequence (default False)

    Returns:
        motif_count -- dict mapping motif -> number of (sequences with) matches
    """
    import ahocorasick
    from operator import itemgetter

    motif_count = {motif: 0 for motif in motifs}

    # Build the automaton once for all motif patterns.
    automaton = ahocorasick.Automaton()
    for idx, key in enumerate(motifs):
        automaton.add_word(key, (idx, key))
    automaton.make_automaton()

    # Sets give O(1) membership tests; the original used O(n) list scans
    # (`match_seq in motifs` / `match_seq in matched_seqs`) inside the loop.
    motif_set = set(motifs)
    for seq in seqs:
        matches = sorted(map(itemgetter(1), automaton.iter(seq)))
        matched_seqs = set()
        for _, match_seq in matches:
            assert match_seq in motif_set
            if allow_multi_match:
                motif_count[match_seq] += 1
            elif match_seq not in matched_seqs:
                # Count a motif at most once per sequence.
                motif_count[match_seq] += 1
                matched_seqs.add(match_seq)

    return motif_count
|
| 165 |
+
|
| 166 |
+
def motifs_hypergeom_test(pos_seqs, neg_seqs, motifs, p_adjust = 'fdr_bh', alpha = 0.05, verbose=False,
                          allow_multi_match=False, **kwargs):
    """
    Perform hypergeometric test to find significantly enriched motifs in positive sequences.
    Returns a list of adjusted p-values.

    Arguments:
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences
        neg_seqs -- list, numpy array or pandas series of negative DNA sequences
        motifs -- list, numpy array or pandas series, a collection of motif patterns
            to be matched to seqs

    Keyword arguments:
        p_adjust -- method used to correct for multiple testing problem. Options are same as
            statsmodels.stats.multitest (default 'fdr_bh')
        alpha -- cutoff FDR/p-value to declare statistical significance (default 0.05)
        verbose -- verbosity argument (default False)
        allow_multi_match -- bool, whether to allow for counting multiple matches (default False)

    Returns:
        pvals -- a list of p-values, one per motif, in `motifs` order.
    """
    from scipy.stats import hypergeom
    import statsmodels.stats.multitest as multi

    # Hypergeometric population: all sequences; "successes" are positives.
    population = len(pos_seqs) + len(neg_seqs)
    n_positive = len(pos_seqs)
    counts_all = count_motif_instances(pos_seqs + neg_seqs, motifs, allow_multi_match=allow_multi_match)
    counts_pos = count_motif_instances(pos_seqs, motifs, allow_multi_match=allow_multi_match)

    pvals = []
    for motif in motifs:
        draws = counts_all[motif]
        hits = counts_pos[motif]
        # Survival function at hits-1 gives P(X >= hits) under the null.
        pval = hypergeom.sf(hits - 1, population, n_positive, draws)
        if verbose and pval < 1e-5:
            print("motif {}: N={}; K={}; n={}; x={}; p={}".format(motif, population, n_positive, draws, hits, pval))
        pvals.append(pval)

    # Correct for multiple testing unless explicitly disabled.
    if p_adjust is not None:
        pvals = list(multi.multipletests(pvals, alpha=alpha, method=p_adjust)[1])
    return pvals
|
| 213 |
+
|
| 214 |
+
def filter_motifs(pos_seqs, neg_seqs, motifs, cutoff=0.05, return_idx=False, **kwargs):
    """
    Wrapper function for returning the actual motifs that passed the hypergeometric test.

    Arguments:
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences
        neg_seqs -- list, numpy array or pandas series of negative DNA sequences
        motifs -- list, numpy array or pandas series, a collection of motif patterns
            to be matched to seqs

    Keyword arguments:
        cutoff -- cutoff FDR/p-value to declare statistical significance. (default 0.05)
        return_idx -- whether only the indices of the motifs are returned. (default False)
        **kwargs -- forwarded to motifs_hypergeom_test

    Returns:
        list of filtered motifs (or indices of the motifs)
    """
    pvals = motifs_hypergeom_test(pos_seqs, neg_seqs, motifs, **kwargs)
    significant = [i for i, pval in enumerate(pvals) if pval < cutoff]
    if return_idx:
        return significant
    return [motifs[i] for i in significant]
|
| 238 |
+
|
| 239 |
+
def merge_motifs(motif_seqs, min_len=5, align_all_ties=True, **kwargs):
    """
    Function to merge similar motifs in input motif_seqs.

    First sort keys of input motif_seqs based on length. For each query motif with length
    guaranteed to >= key motif, perform pairwise alignment between them.

    If can be aligned, find out best alignment among all combinations, then adjust start
    and end position of high attention region based on left/right offsets calculated by
    alignment of the query and key motifs.

    If cannot be aligned with any existing key motifs, add to the new dict as new key motif.

    Returns a new dict containing merged motifs.

    Arguments:
        motif_seqs -- nested dict, with the following structure:
            {motif: {seq_idx: idx, atten_region_pos: (start, end)}}
            where seq_idx indicates indices of pos_seqs containing a motif, and
            atten_region_pos indicates where the high attention region is located.

    Keyword arguments:
        min_len -- int, specified minimum length threshold for contiguous region
            (default 5)
        align_all_ties -- bool, whether to keep all best alignments when ties encountered
            (default True)
        **kwargs -- other input arguments, may include:
            - cond: custom condition used to declare successful alignment.
              default is score > max of (min_len - 1) and (1/2 times min length
              of two motifs aligned)

    Returns:
        merged_motif_seqs -- nested dict with same structure as `motif_seqs`
    """
    from Bio import Align

    ### TODO: modify algorithm to improve efficiency later
    aligner = Align.PairwiseAligner()
    aligner.internal_gap_score = -10000.0  # prohibit internal gaps

    merged_motif_seqs = {}

    def _merge_into(key_motif, motif, alignment):
        """Fold `motif`'s instances into `key_motif`, shifting attention
        region coordinates by the offsets implied by `alignment`.
        (This logic was duplicated verbatim in both tie-handling branches.)"""
        # Offsets are always computed as query - key.
        left_offset = alignment.aligned[0][0][0] - alignment.aligned[1][0][0]
        query_end = alignment.aligned[0][0][1]
        key_end = alignment.aligned[1][0][1]
        if (query_end <= len(motif)) and (key_end == len(key_motif)):  # inside
            right_offset = len(motif) - query_end
        elif (query_end == len(motif)) and (key_end < len(key_motif)):  # left shift
            right_offset = key_end - len(key_motif)
        elif (query_end < len(motif)) and (key_end == len(key_motif)):  # right shift
            right_offset = len(motif) - query_end

        # Move seq indices over to the key motif.
        merged_motif_seqs[key_motif]['seq_idx'].extend(motif_seqs[motif]['seq_idx'])

        # Recompute atten_region_pos with the alignment offsets applied.
        new_atten_region_pos = [(pos[0] + left_offset, pos[1] - right_offset)
                                for pos in motif_seqs[motif]['atten_region_pos']]
        merged_motif_seqs[key_motif]['atten_region_pos'].extend(new_atten_region_pos)

    # Shorter motifs become keys first, so each query is at least as long
    # as every existing key when it is aligned.
    for motif in sorted(motif_seqs, key=len):  # query motif
        if not merged_motif_seqs:  # if empty
            merged_motif_seqs[motif] = motif_seqs[motif]  # add first one
            continue

        # Collect every key motif this query aligns to, to find the best.
        alignments = []
        key_motifs = []
        for key_motif in merged_motif_seqs.keys():  # key motif
            if motif == key_motif:  # do not attempt to align to self
                continue
            # First is query, second is key within new dict;
            # query is guaranteed to be length >= key after sorting keys.
            alignment = aligner.align(motif, key_motif)[0]

            # Condition to declare successful alignment.
            cond = max((min_len - 1), 0.5 * min(len(motif), len(key_motif)))
            if 'cond' in kwargs:
                cond = kwargs['cond']  # override

            if alignment.score >= cond:  # exists key that can align
                alignments.append(alignment)
                key_motifs.append(key_motif)

        if alignments:  # if aligned, merge into the best-scoring key(s)
            best_score = max(alignments, key=lambda alignment: alignment.score)
            best_idx = [i for i, score in enumerate(alignments) if score == best_score]

            if align_all_ties:
                for i in best_idx:
                    _merge_into(key_motifs[i], motif, alignments[i])
            else:
                _merge_into(key_motifs[best_idx[0]], motif, alignments[best_idx[0]])
        else:  # cannot align to anything, add to new dict as independent key
            merged_motif_seqs[motif] = motif_seqs[motif]  # add new one

    return merged_motif_seqs
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def make_window(motif_seqs, pos_seqs, window_size=24):
    """
    Function to extract fixed, equal length sequences centered at high-attention motif instance.

    Returns new dict containing seqs with fixed window_size.

    Arguments:
        motif_seqs -- nested dict, with the following structure:
            {motif: {seq_idx: idx, atten_region_pos: (start, end)}}
            where seq_idx indicates indices of pos_seqs containing a motif, and
            atten_region_pos indicates where the high attention region is located.
        pos_seqs -- list, numpy array or pandas series of positive DNA sequences

    Keyword arguments:
        window_size -- int, specified window size to be final motif length
            (default 24)

    Returns:
        new_motif_seqs -- nested dict with same structure as `motif_seqs`,
            plus a 'seqs' list of the extracted window sequences.
    """
    new_motif_seqs = {}

    # Extract fixed-length sequences based on window_size.
    for motif, instances in motif_seqs.items():
        new_motif_seqs[motif] = {'seq_idx': [], 'atten_region_pos': [], 'seqs': []}
        for i, coord in enumerate(instances['atten_region_pos']):
            atten_len = coord[1] - coord[0]
            # Pad symmetrically to window_size; for odd total padding the
            # extra base goes on the right. (The original duplicated the
            # whole append block across even/odd branches; floor-division
            # padding is exactly equivalent for both parities.)
            pad = window_size - atten_len
            left_pad = pad // 2
            right_pad = pad - left_pad
            new_coord = (int(coord[0] - left_pad), int(coord[1] + right_pad))
            seq_idx = instances['seq_idx'][i]
            # Keep only windows lying fully inside the source sequence.
            if (new_coord[0] >= 0) and (new_coord[1] < len(pos_seqs[seq_idx])):
                new_motif_seqs[motif]['seq_idx'].append(seq_idx)
                new_motif_seqs[motif]['atten_region_pos'].append(new_coord)
                new_motif_seqs[motif]['seqs'].append(pos_seqs[seq_idx][new_coord[0]:new_coord[1]])

    return new_motif_seqs
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
### make full pipeline
def motif_analysis(pos_seqs,
                   neg_seqs,
                   pos_atten_scores,
                   window_size = 24,
                   min_len = 4,
                   pval_cutoff = 0.005,
                   min_n_motif = 3,
                   align_all_ties = True,
                   save_file_dir = None,
                   **kwargs
                   ):

    """
    Wrapper function of full motif analysis tool based on DNABERT-viz.

    Pipeline: find high-attention regions -> hypergeometric filtering ->
    merge similar motifs -> extract fixed-length windows -> drop rare motifs
    -> optionally save instance files and weblogos.

    Arguments:
    pos_seqs -- list, numpy array or pandas series of positive DNA sequences
    neg_seqs -- list, numpy array or pandas series of negative DNA sequences
    pos_atten_scores -- numpy array of attention scores for postive DNA sequence

    Keyword arguments:
    window_size -- int, specified window size to be final motif length
        (default 24)
    min_len -- int, specified minimum length threshold for contiguous region
        (default 4)
    pval_cutoff -- float, cutoff FDR/p-value to declare statistical significance. (default 0.005)
    min_n_motif -- int, minimum instance inside motif to be filtered (default 3)
    align_all_ties -- bool, whether to keep all best alignments when ties encountered (default True)
    save_file_dir -- str, path to save outputs (default None)
    **kwargs -- other input arguments, may include:
        - verbose: bool, verbosity controller
        - atten_cond: custom conditions to filter/select high attention
            (list of boolean arrays)
        - return_idx: whether the indices of the motifs are only returned.
        - align_cond: custom condition used to declare successful alignment.
            default is score > max of (min_len -1) and (1/2 times min length of two motifs aligned)

    Returns:
    merged_motif_seqs -- nested dict, with the following structure:
        {motif: {seq_idx: idx, atten_region_pos: (start, end)}}
        where seq_idx indicates indices of pos_seqs containing a motif, and
        atten_region_pos indicates where the high attention region is located.

    """
    from Bio import motifs
    from Bio.Seq import Seq

    verbose = False
    if 'verbose' in kwargs:
        verbose = kwargs['verbose']

    if verbose:
        print("*** Begin motif analysis ***")
    pos_seqs = list(pos_seqs)
    neg_seqs = list(neg_seqs)

    if verbose:
        print("* pos_seqs: {}; neg_seqs: {}".format(len(pos_seqs),len(neg_seqs)))

    # One attention-score row per positive sequence.
    assert len(pos_seqs) == len(pos_atten_scores)

    # NOTE(review): max_seq_len is computed but never used below.
    max_seq_len = len(max(pos_seqs, key=len))
    motif_seqs = {}

    ## find the motif regions
    if verbose:
        print("* Finding high attention motif regions")
    for i, score in enumerate(pos_atten_scores):
        seq_len = len(pos_seqs[i])
        # Truncate scores to the sequence length (scores may be padded).
        score = score[0:seq_len]

        # handle kwargs
        if 'atten_cond' in kwargs:
            motif_regions = find_high_attention(score, min_len=min_len, cond=kwargs['atten_cond'])
        else:
            motif_regions = find_high_attention(score, min_len=min_len)

        # Group identical high-attention substrings across sequences.
        for motif_idx in motif_regions:
            seq = pos_seqs[i][motif_idx[0]:motif_idx[1]]
            if seq not in motif_seqs:
                motif_seqs[seq] = {'seq_idx': [i], 'atten_region_pos':[(motif_idx[0],motif_idx[1])]}
            else:
                motif_seqs[seq]['seq_idx'].append(i)
                motif_seqs[seq]['atten_region_pos'].append((motif_idx[0],motif_idx[1]))


    # filter motifs
    return_idx = False
    if 'return_idx' in kwargs:
        return_idx = kwargs['return_idx']
        # Pop so the remaining kwargs can be forwarded to filter_motifs.
        kwargs.pop('return_idx')

    if verbose:
        print("* Filtering motifs by hypergeometric test")
    motifs_to_keep = filter_motifs(pos_seqs,
                                   neg_seqs,
                                   list(motif_seqs.keys()),
                                   cutoff = pval_cutoff,
                                   return_idx=return_idx,
                                   **kwargs)

    # NOTE(review): if return_idx is True, motifs_to_keep holds indices, not
    # motif strings, and this dict comprehension would raise KeyError.
    motif_seqs = {k: motif_seqs[k] for k in motifs_to_keep}

    # merge motifs
    if verbose:
        print("* Merging similar motif instances")
    if 'align_cond' in kwargs:
        merged_motif_seqs = merge_motifs(motif_seqs, min_len=min_len,
                                         align_all_ties = align_all_ties,
                                         cond=kwargs['align_cond'])
    else:
        merged_motif_seqs = merge_motifs(motif_seqs, min_len=min_len,
                                         align_all_ties = align_all_ties)

    # make fixed-length window sequences
    if verbose:
        print("* Making fixed_length window = {}".format(window_size))
    merged_motif_seqs = make_window(merged_motif_seqs, pos_seqs, window_size=window_size)

    # remove motifs with only few instances
    if verbose:
        print("* Removing motifs with less than {} instances".format(min_n_motif))
    merged_motif_seqs = {k: coords for k, coords in merged_motif_seqs.items() if len(coords['seq_idx']) >= min_n_motif}

    if save_file_dir is not None:
        if verbose:
            print("* Saving outputs to directory")
        os.makedirs(save_file_dir, exist_ok=True)
        for motif, instances in merged_motif_seqs.items():
            # saving to files
            with open(save_file_dir+'/motif_{}_{}.txt'.format(motif, len(instances['seq_idx'])), 'w') as f:
                for seq in instances['seqs']:
                    f.write(seq+'\n')
            # make weblogo
            seqs = [Seq(v) for i,v in enumerate(instances['seqs'])]
            m = motifs.create(seqs)
            m.weblogo(save_file_dir+"/motif_{}_{}_weblogo.png".format(motif, len(instances['seq_idx'])), format='png_print',
                      show_fineprint=False, show_ends=False, color_scheme='color_classic')

    return merged_motif_seqs
|
save2cache.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
import random
|
| 7 |
+
import re
|
| 8 |
+
import shutil
|
| 9 |
+
from typing import Dict, List, Tuple
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from multiprocessing import Pool
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 16 |
+
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
from tqdm import tqdm, trange
|
| 19 |
+
import itertools
|
| 20 |
+
|
| 21 |
+
from transformers import (
|
| 22 |
+
WEIGHTS_NAME,
|
| 23 |
+
AdamW,
|
| 24 |
+
BertConfig,
|
| 25 |
+
BertForMaskedLM,
|
| 26 |
+
BertTokenizer,
|
| 27 |
+
DNATokenizer,
|
| 28 |
+
#myTokenizer,
|
| 29 |
+
#MotifTokenizer,
|
| 30 |
+
CamembertConfig,
|
| 31 |
+
CamembertForMaskedLM,
|
| 32 |
+
CamembertTokenizer,
|
| 33 |
+
DistilBertConfig,
|
| 34 |
+
DistilBertForMaskedLM,
|
| 35 |
+
DistilBertTokenizer,
|
| 36 |
+
GPT2Config,
|
| 37 |
+
GPT2LMHeadModel,
|
| 38 |
+
GPT2Tokenizer,
|
| 39 |
+
OpenAIGPTConfig,
|
| 40 |
+
OpenAIGPTLMHeadModel,
|
| 41 |
+
OpenAIGPTTokenizer,
|
| 42 |
+
PreTrainedModel,
|
| 43 |
+
PreTrainedTokenizer,
|
| 44 |
+
RobertaConfig,
|
| 45 |
+
RobertaForMaskedLM,
|
| 46 |
+
RobertaTokenizer,
|
| 47 |
+
get_linear_schedule_with_warmup,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# TensorBoard writer: prefer the implementation bundled with torch, fall
# back to the standalone tensorboardX package on older torch installs.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


# Maps a --model_type CLI value to its (config class, masked-LM model class,
# tokenizer class) triple.
MODEL_CLASSES = {
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "dna": (BertConfig, BertForMaskedLM, DNATokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
    #"myBert": (BertConfig, BertForMaskedLM, myTokenizer),
    #"motifBert": (BertConfig, BertForMaskedLM, MotifTokenizer)
}
|
| 68 |
+
|
| 69 |
+
def convert_line_to_example(tokenizer, lines, max_length, add_special_tokens=True):
    """Tokenize a batch of raw text lines and return their input-id lists."""
    encoded = tokenizer.batch_encode_plus(
        lines, add_special_tokens=add_special_tokens, max_length=max_length
    )
    return encoded["input_ids"]
|
| 72 |
+
|
| 73 |
+
class LineByLineTextDataset(Dataset):
    """Dataset that tokenizes a text file with one training example per line.

    With args.n_process == 1 tokenization happens in-process; otherwise the
    lines are split into slices and tokenized by a multiprocessing Pool, and
    the resulting features are pickled to a cache file next to the input.
    """

    # Annotation is a forward reference so the class can be defined without
    # transformers imported at module scope.
    def __init__(self, tokenizer: "PreTrainedTokenizer", args, file_path: str, block_size=512):
        assert os.path.isfile(file_path)
        # Here, we do not cache the features (single-process path), operating
        # under the assumption that we will soon use fast multithreaded
        # tokenizers from the `tokenizers` repo everywhere =)
        directory, filename = os.path.split(file_path)
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
        )

        # Fixed: the original passed logger-style "%s" args to print(), which
        # printed the literal "%s" instead of interpolating.
        print("Creating features from dataset file at {}".format(file_path))

        with open(file_path, encoding="utf-8") as f:
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

        if args.n_process == 1:
            self.examples = tokenizer.batch_encode_plus(
                lines, add_special_tokens=True, max_length=block_size
            )["input_ids"]
        else:
            # Split `lines` into n_proc contiguous slices; the last slice
            # absorbs the remainder.
            n_proc = args.n_process
            p = Pool(n_proc)
            indexes = [0]
            len_slice = int(len(lines) / n_proc)
            for i in range(1, n_proc + 1):
                if i != n_proc:
                    indexes.append(len_slice * i)
                else:
                    indexes.append(len(lines))
            results = []
            for i in range(n_proc):
                results.append(
                    p.apply_async(convert_line_to_example, [tokenizer, lines[indexes[i]:indexes[i + 1]], block_size])
                )
                print("{} start".format(i))
            p.close()
            p.join()

            # apply_async results come back in submission order, so example
            # order matches the input file.
            self.examples = []
            for result in results:
                self.examples.extend(result.get())
            print("Saving features into cached file {}".format(cached_features_file))
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return torch.tensor(self.examples[i], dtype=torch.long)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the dataset for the eval file (evaluate=True) or the train file."""
    file_path = args.eval_data_file if evaluate else args.train_data_file
    print(file_path)
    if args.line_by_line:
        return LineByLineTextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
    # NOTE(review): TextDataset is not defined anywhere in this file; this
    # branch raises NameError unless it is provided elsewhere -- confirm.
    return TextDataset(tokenizer, args, file_path=file_path, block_size=args.block_size)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main():
    """Tokenize and cache the eval/train files given on the command line.

    Relies on the module-level ``args`` and ``tokenizer`` set up in the
    ``__main__`` block.  The constructed datasets are not used here; the
    useful side effect is the on-disk feature cache they create.
    """
    if args.eval_data_file:
        load_and_cache_examples(args, tokenizer, evaluate=True)
        print('done')

    if args.train_data_file:
        load_and_cache_examples(args, tokenizer, evaluate=False)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)."
    )

    # Other parameters
    parser.add_argument(
        "--eval_data_file",
        default=None,
        type=str,
        help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
    )
    parser.add_argument(
        "--line_by_line",
        action="store_true",
        help="Whether distinct lines of text in the dataset are to be handled as distinct sequences.",
    )

    parser.add_argument(
        "--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
    )

    parser.add_argument(
        "--config_name",
        default=None,
        type=str,
        help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
    )

    parser.add_argument(
        "--block_size",
        default=-1,
        type=int,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )
    # NOTE(review): the help text below was copy-pasted from --block_size and does
    # not describe --specialpath; confirm this option's real meaning and fix it.
    parser.add_argument(
        "--specialpath",
        type=str,
        help="Optional input sequence length after tokenization."
        "The training dataset will be truncated in block of this size for training."
        "Default to the model max input length for single sentence inputs (take into account special tokens).",
    )

    parser.add_argument("--n_process", type=int, default=1, help="Number of worker processes used to tokenize the input file.")
    args = parser.parse_args()

    # MODEL_CLASSES (defined earlier in this file) maps a model-type string to
    # its (config class, model class, tokenizer class) triple.
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]

    if args.config_name:
        config = config_class.from_pretrained(args.config_name, cache_dir=None)
    else:
        config = config_class()

    if args.tokenizer_name:
        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=None)
    else:
        # Fixed: the two concatenated message parts were missing a separating space.
        raise ValueError(
            "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it, "
            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
        )

    if args.block_size <= 0:
        # Our input block size will be the max possible for the model
        args.block_size = tokenizer.max_len
    else:
        args.block_size = min(args.block_size, tokenizer.max_len)

    main()
|
| 224 |
+
|
setup.cfg
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[isort]
|
| 2 |
+
ensure_newline_before_comments = True
|
| 3 |
+
force_grid_wrap = 0
|
| 4 |
+
include_trailing_comma = True
|
| 5 |
+
known_first_party = transformers
|
| 6 |
+
known_third_party =
|
| 7 |
+
absl
|
| 8 |
+
fairseq
|
| 9 |
+
fastprogress
|
| 10 |
+
git
|
| 11 |
+
h5py
|
| 12 |
+
MeCab
|
| 13 |
+
nltk
|
| 14 |
+
numpy
|
| 15 |
+
packaging
|
| 16 |
+
PIL
|
| 17 |
+
psutil
|
| 18 |
+
pytorch_lightning
|
| 19 |
+
seqeval
|
| 20 |
+
sklearn
|
| 21 |
+
tensorboardX
|
| 22 |
+
tensorflow
|
| 23 |
+
tensorflow_datasets
|
| 24 |
+
torch
|
| 25 |
+
torchtext
|
| 26 |
+
torchvision
|
| 27 |
+
torch_xla
|
| 28 |
+
|
| 29 |
+
line_length = 119
|
| 30 |
+
lines_after_imports = 2
|
| 31 |
+
multi_line_output = 3
|
| 32 |
+
use_parentheses = True
|
| 33 |
+
|
| 34 |
+
[flake8]
|
| 35 |
+
ignore = E203, E501, W503
|
| 36 |
+
max-line-length = 119
|
setup.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
|
| 3 |
+
|
| 4 |
+
To create the package for pypi.
|
| 5 |
+
|
| 6 |
+
1. Change the version in __init__.py, setup.py as well as docs/source/conf.py.
|
| 7 |
+
|
| 8 |
+
2. Commit these changes with the message: "Release: VERSION"
|
| 9 |
+
|
| 10 |
+
3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
|
| 11 |
+
Push the tag to git: git push --tags origin master
|
| 12 |
+
|
| 13 |
+
4. Build both the sources and the wheel. Do not change anything in setup.py between
|
| 14 |
+
creating the wheel and the source distribution (obviously).
|
| 15 |
+
|
| 16 |
+
For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
|
| 17 |
+
(this will build a wheel for the python version you use to build it).
|
| 18 |
+
|
| 19 |
+
For the sources, run: "python setup.py sdist"
|
| 20 |
+
You should now have a /dist directory with both .whl and .tar.gz source versions.
|
| 21 |
+
|
| 22 |
+
5. Check that everything looks correct by uploading the package to the pypi test server:
|
| 23 |
+
|
| 24 |
+
twine upload dist/* -r pypitest
|
| 25 |
+
(pypi suggest using twine as other methods upload files via plaintext.)
|
| 26 |
+
You may have to specify the repository url, use the following command then:
|
| 27 |
+
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
| 28 |
+
|
| 29 |
+
Check that you can install it in a virtualenv by running:
|
| 30 |
+
pip install -i https://testpypi.python.org/pypi transformers
|
| 31 |
+
|
| 32 |
+
6. Upload the final version to actual pypi:
|
| 33 |
+
twine upload dist/* -r pypi
|
| 34 |
+
|
| 35 |
+
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
| 36 |
+
|
| 37 |
+
8. Update the documentation commit in .circleci/deploy.sh for the accurate documentation to be displayed
|
| 38 |
+
|
| 39 |
+
9. Update README.md to redirect to correct documentation.
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
import shutil
|
| 43 |
+
from pathlib import Path
|
| 44 |
+
|
| 45 |
+
from setuptools import find_packages, setup
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Remove stale transformers.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
stale_egg_info = Path(__file__).parent / "transformers.egg-info"
if stale_egg_info.exists():
    warning = (
        "Warning: {} exists.\n\n"
        "If you recently updated transformers to 3.0 or later, this is expected,\n"
        "but it may prevent transformers from installing in editable mode.\n\n"
        "This directory is automatically generated by Python's packaging tools.\n"
        "I will remove it now.\n\n"
        "See https://github.com/pypa/pip/issues/5466 for details.\n"
    ).format(stale_egg_info)
    print(warning)
    shutil.rmtree(stale_egg_info)
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Optional dependency groups, installable via ``pip install transformers[<extra>]``.
extras = {
    "mecab": ["mecab-python3"],
    "sklearn": ["scikit-learn"],
    "tf": ["tensorflow"],
    "tf-cpu": ["tensorflow-cpu"],
    "torch": ["torch"],
}

extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]

extras["testing"] = ["pytest", "pytest-xdist"]
extras["quality"] = ["black", "isort", "flake8"]
extras["docs"] = ["recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme"]
extras["dev"] = extras["testing"] + extras["quality"] + ["mecab-python3", "scikit-learn", "tensorflow", "torch"]
|
| 79 |
+
|
| 80 |
+
# Core runtime dependencies; comments explain why each pinned/optional item exists.
install_requires = [
    "numpy",
    "tokenizers == 0.5.0",
    # accessing files from S3 directly
    "boto3",
    # filesystem locks e.g. to prevent parallel downloads
    "filelock",
    # for downloading models over HTTPS
    "requests",
    # progress bars in model download and training scripts
    "tqdm >= 4.27",
    # for OpenAI GPT
    "regex != 2019.12.17",
    # for XLNet
    "sentencepiece",
    # for XLM
    "sacremoses",
]

setup(
    name="transformers",
    version="2.5.0",
    author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
    author_email="thomas@huggingface.co",
    description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
    # Fixed: the previous open("README.md") handle was never closed; read_text
    # opens and closes the file in one call.
    long_description=Path("README.md").read_text(encoding="utf-8"),
    long_description_content_type="text/markdown",
    keywords="NLP deep learning transformer pytorch tensorflow BERT GPT GPT-2 google openai CMU",
    license="Apache",
    url="https://github.com/huggingface/transformers",
    package_dir={"": "src"},
    packages=find_packages("src"),
    install_requires=install_requires,
    extras_require=extras,
    scripts=["transformers-cli"],
    python_requires=">=3.5.0",
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
|
src/transformers/__init__.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa
|
| 2 |
+
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
| 3 |
+
# module, but to preserve other warnings. So, don't check this module at all.
|
| 4 |
+
|
| 5 |
+
__version__ = "2.5.0"
|
| 6 |
+
|
| 7 |
+
# Work around to update TensorFlow's absl.logging threshold which alters the
|
| 8 |
+
# default Python logging output behavior when present.
|
| 9 |
+
# see: https://github.com/abseil/abseil-py/issues/99
|
| 10 |
+
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
|
| 11 |
+
try:
|
| 12 |
+
import absl.logging
|
| 13 |
+
except ImportError:
|
| 14 |
+
pass
|
| 15 |
+
else:
|
| 16 |
+
absl.logging.set_verbosity("info")
|
| 17 |
+
absl.logging.set_stderrthreshold("info")
|
| 18 |
+
absl.logging._warn_preinit_stderr = False
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
|
| 22 |
+
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
|
| 23 |
+
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
|
| 24 |
+
from .configuration_bart import BartConfig
|
| 25 |
+
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
|
| 26 |
+
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
|
| 27 |
+
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
|
| 28 |
+
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
|
| 29 |
+
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
|
| 30 |
+
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
| 31 |
+
from .configuration_mmbt import MMBTConfig
|
| 32 |
+
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
|
| 33 |
+
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
| 34 |
+
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
| 35 |
+
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
|
| 36 |
+
|
| 37 |
+
# Configurations
|
| 38 |
+
from .configuration_utils import PretrainedConfig
|
| 39 |
+
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
|
| 40 |
+
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
| 41 |
+
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
|
| 42 |
+
from .data import (
|
| 43 |
+
DataProcessor,
|
| 44 |
+
InputExample,
|
| 45 |
+
InputFeatures,
|
| 46 |
+
SingleSentenceClassificationProcessor,
|
| 47 |
+
SquadExample,
|
| 48 |
+
SquadFeatures,
|
| 49 |
+
SquadV1Processor,
|
| 50 |
+
SquadV2Processor,
|
| 51 |
+
glue_convert_examples_to_features,
|
| 52 |
+
glue_output_modes,
|
| 53 |
+
glue_processors,
|
| 54 |
+
glue_tasks_num_labels,
|
| 55 |
+
is_sklearn_available,
|
| 56 |
+
squad_convert_examples_to_features,
|
| 57 |
+
xnli_output_modes,
|
| 58 |
+
xnli_processors,
|
| 59 |
+
xnli_tasks_num_labels,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
# Files and general utilities
|
| 63 |
+
from .file_utils import (
|
| 64 |
+
CONFIG_NAME,
|
| 65 |
+
MODEL_CARD_NAME,
|
| 66 |
+
PYTORCH_PRETRAINED_BERT_CACHE,
|
| 67 |
+
PYTORCH_TRANSFORMERS_CACHE,
|
| 68 |
+
TF2_WEIGHTS_NAME,
|
| 69 |
+
TF_WEIGHTS_NAME,
|
| 70 |
+
TRANSFORMERS_CACHE,
|
| 71 |
+
WEIGHTS_NAME,
|
| 72 |
+
add_end_docstrings,
|
| 73 |
+
add_start_docstrings,
|
| 74 |
+
cached_path,
|
| 75 |
+
is_tf_available,
|
| 76 |
+
is_torch_available,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
# Model Cards
|
| 80 |
+
from .modelcard import ModelCard
|
| 81 |
+
|
| 82 |
+
# TF 2.0 <=> PyTorch conversion utilities
|
| 83 |
+
from .modeling_tf_pytorch_utils import (
|
| 84 |
+
convert_tf_weight_name_to_pt_weight_name,
|
| 85 |
+
load_pytorch_checkpoint_in_tf2_model,
|
| 86 |
+
load_pytorch_model_in_tf2_model,
|
| 87 |
+
load_pytorch_weights_in_tf2_model,
|
| 88 |
+
load_tf2_checkpoint_in_pytorch_model,
|
| 89 |
+
load_tf2_model_in_pytorch_model,
|
| 90 |
+
load_tf2_weights_in_pytorch_model,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Pipelines
|
| 94 |
+
from .pipelines import (
|
| 95 |
+
CsvPipelineDataFormat,
|
| 96 |
+
FeatureExtractionPipeline,
|
| 97 |
+
FillMaskPipeline,
|
| 98 |
+
JsonPipelineDataFormat,
|
| 99 |
+
NerPipeline,
|
| 100 |
+
PipedPipelineDataFormat,
|
| 101 |
+
Pipeline,
|
| 102 |
+
PipelineDataFormat,
|
| 103 |
+
QuestionAnsweringPipeline,
|
| 104 |
+
TextClassificationPipeline,
|
| 105 |
+
TokenClassificationPipeline,
|
| 106 |
+
pipeline,
|
| 107 |
+
)
|
| 108 |
+
from .tokenization_albert import AlbertTokenizer
|
| 109 |
+
from .tokenization_auto import AutoTokenizer
|
| 110 |
+
from .tokenization_bart import BartTokenizer
|
| 111 |
+
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
|
| 112 |
+
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
|
| 113 |
+
from .tokenization_camembert import CamembertTokenizer
|
| 114 |
+
from .tokenization_ctrl import CTRLTokenizer
|
| 115 |
+
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
|
| 116 |
+
from .tokenization_flaubert import FlaubertTokenizer
|
| 117 |
+
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
|
| 118 |
+
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
|
| 119 |
+
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
|
| 120 |
+
from .tokenization_t5 import T5Tokenizer
|
| 121 |
+
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
|
| 122 |
+
from .tokenization_dna import DNATokenizer
|
| 123 |
+
|
| 124 |
+
# Tokenizers
|
| 125 |
+
from .tokenization_utils import PreTrainedTokenizer
|
| 126 |
+
from .tokenization_xlm import XLMTokenizer
|
| 127 |
+
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
| 128 |
+
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if is_sklearn_available():
|
| 135 |
+
from .data import glue_compute_metrics, xnli_compute_metrics
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Modeling
|
| 139 |
+
if is_torch_available():
|
| 140 |
+
from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
|
| 141 |
+
from .modeling_auto import (
|
| 142 |
+
AutoModel,
|
| 143 |
+
AutoModelForPreTraining,
|
| 144 |
+
AutoModelForSequenceClassification,
|
| 145 |
+
AutoModelForQuestionAnswering,
|
| 146 |
+
AutoModelWithLMHead,
|
| 147 |
+
AutoModelForTokenClassification,
|
| 148 |
+
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
from .modeling_bert import (
|
| 152 |
+
BertPreTrainedModel,
|
| 153 |
+
BertModel,
|
| 154 |
+
BertForPreTraining,
|
| 155 |
+
BertForMaskedLM,
|
| 156 |
+
BertForNextSentencePrediction,
|
| 157 |
+
BertForSequenceClassification,
|
| 158 |
+
BertForLongSequenceClassification,
|
| 159 |
+
BertForLongSequenceClassificationCat,
|
| 160 |
+
BertForMultipleChoice,
|
| 161 |
+
BertForTokenClassification,
|
| 162 |
+
BertForQuestionAnswering,
|
| 163 |
+
load_tf_weights_in_bert,
|
| 164 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 165 |
+
)
|
| 166 |
+
from .modeling_openai import (
|
| 167 |
+
OpenAIGPTPreTrainedModel,
|
| 168 |
+
OpenAIGPTModel,
|
| 169 |
+
OpenAIGPTLMHeadModel,
|
| 170 |
+
OpenAIGPTDoubleHeadsModel,
|
| 171 |
+
load_tf_weights_in_openai_gpt,
|
| 172 |
+
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 173 |
+
)
|
| 174 |
+
from .modeling_transfo_xl import (
|
| 175 |
+
TransfoXLPreTrainedModel,
|
| 176 |
+
TransfoXLModel,
|
| 177 |
+
TransfoXLLMHeadModel,
|
| 178 |
+
AdaptiveEmbedding,
|
| 179 |
+
load_tf_weights_in_transfo_xl,
|
| 180 |
+
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 181 |
+
)
|
| 182 |
+
from .modeling_gpt2 import (
|
| 183 |
+
GPT2PreTrainedModel,
|
| 184 |
+
GPT2Model,
|
| 185 |
+
GPT2LMHeadModel,
|
| 186 |
+
GPT2DoubleHeadsModel,
|
| 187 |
+
load_tf_weights_in_gpt2,
|
| 188 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 189 |
+
)
|
| 190 |
+
from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
| 191 |
+
from .modeling_xlnet import (
|
| 192 |
+
XLNetPreTrainedModel,
|
| 193 |
+
XLNetModel,
|
| 194 |
+
XLNetLMHeadModel,
|
| 195 |
+
XLNetForSequenceClassification,
|
| 196 |
+
XLNetForTokenClassification,
|
| 197 |
+
XLNetForMultipleChoice,
|
| 198 |
+
XLNetForQuestionAnsweringSimple,
|
| 199 |
+
XLNetForQuestionAnswering,
|
| 200 |
+
load_tf_weights_in_xlnet,
|
| 201 |
+
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 202 |
+
)
|
| 203 |
+
from .modeling_xlm import (
|
| 204 |
+
XLMPreTrainedModel,
|
| 205 |
+
XLMModel,
|
| 206 |
+
XLMWithLMHeadModel,
|
| 207 |
+
XLMForSequenceClassification,
|
| 208 |
+
XLMForQuestionAnswering,
|
| 209 |
+
XLMForQuestionAnsweringSimple,
|
| 210 |
+
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 211 |
+
)
|
| 212 |
+
from .modeling_bart import BartForSequenceClassification, BartModel, BartForMaskedLM
|
| 213 |
+
from .modeling_roberta import (
|
| 214 |
+
RobertaForMaskedLM,
|
| 215 |
+
RobertaModel,
|
| 216 |
+
RobertaForSequenceClassification,
|
| 217 |
+
RobertaForMultipleChoice,
|
| 218 |
+
RobertaForTokenClassification,
|
| 219 |
+
RobertaForQuestionAnswering,
|
| 220 |
+
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 221 |
+
)
|
| 222 |
+
from .modeling_camembert import (
|
| 223 |
+
CamembertForMaskedLM,
|
| 224 |
+
CamembertModel,
|
| 225 |
+
CamembertForSequenceClassification,
|
| 226 |
+
CamembertForTokenClassification,
|
| 227 |
+
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 228 |
+
)
|
| 229 |
+
from .modeling_distilbert import (
|
| 230 |
+
DistilBertPreTrainedModel,
|
| 231 |
+
DistilBertForMaskedLM,
|
| 232 |
+
DistilBertModel,
|
| 233 |
+
DistilBertForSequenceClassification,
|
| 234 |
+
DistilBertForQuestionAnswering,
|
| 235 |
+
DistilBertForTokenClassification,
|
| 236 |
+
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 237 |
+
)
|
| 238 |
+
from .modeling_camembert import (
|
| 239 |
+
CamembertForMaskedLM,
|
| 240 |
+
CamembertModel,
|
| 241 |
+
CamembertForSequenceClassification,
|
| 242 |
+
CamembertForMultipleChoice,
|
| 243 |
+
CamembertForTokenClassification,
|
| 244 |
+
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 245 |
+
)
|
| 246 |
+
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
|
| 247 |
+
from .modeling_t5 import (
|
| 248 |
+
T5PreTrainedModel,
|
| 249 |
+
T5Model,
|
| 250 |
+
T5WithLMHeadModel,
|
| 251 |
+
load_tf_weights_in_t5,
|
| 252 |
+
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 253 |
+
)
|
| 254 |
+
from .modeling_albert import (
|
| 255 |
+
AlbertPreTrainedModel,
|
| 256 |
+
AlbertModel,
|
| 257 |
+
AlbertForMaskedLM,
|
| 258 |
+
AlbertForSequenceClassification,
|
| 259 |
+
AlbertForQuestionAnswering,
|
| 260 |
+
load_tf_weights_in_albert,
|
| 261 |
+
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 262 |
+
)
|
| 263 |
+
from .modeling_xlm_roberta import (
|
| 264 |
+
XLMRobertaForMaskedLM,
|
| 265 |
+
XLMRobertaModel,
|
| 266 |
+
XLMRobertaForMultipleChoice,
|
| 267 |
+
XLMRobertaForSequenceClassification,
|
| 268 |
+
XLMRobertaForTokenClassification,
|
| 269 |
+
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 270 |
+
)
|
| 271 |
+
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
|
| 272 |
+
|
| 273 |
+
from .modeling_flaubert import (
|
| 274 |
+
FlaubertModel,
|
| 275 |
+
FlaubertWithLMHeadModel,
|
| 276 |
+
FlaubertForSequenceClassification,
|
| 277 |
+
FlaubertForQuestionAnswering,
|
| 278 |
+
FlaubertForQuestionAnsweringSimple,
|
| 279 |
+
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Optimization
|
| 283 |
+
from .optimization import (
|
| 284 |
+
AdamW,
|
| 285 |
+
get_constant_schedule,
|
| 286 |
+
get_constant_schedule_with_warmup,
|
| 287 |
+
get_cosine_schedule_with_warmup,
|
| 288 |
+
get_cosine_with_hard_restarts_schedule_with_warmup,
|
| 289 |
+
get_linear_schedule_with_warmup,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# TensorFlow
|
| 294 |
+
if is_tf_available():
|
| 295 |
+
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
|
| 296 |
+
from .modeling_tf_auto import (
|
| 297 |
+
TFAutoModel,
|
| 298 |
+
TFAutoModelForPreTraining,
|
| 299 |
+
TFAutoModelForSequenceClassification,
|
| 300 |
+
TFAutoModelForQuestionAnswering,
|
| 301 |
+
TFAutoModelWithLMHead,
|
| 302 |
+
TFAutoModelForTokenClassification,
|
| 303 |
+
TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
from .modeling_tf_bert import (
|
| 307 |
+
TFBertPreTrainedModel,
|
| 308 |
+
TFBertMainLayer,
|
| 309 |
+
TFBertEmbeddings,
|
| 310 |
+
TFBertModel,
|
| 311 |
+
TFBertForPreTraining,
|
| 312 |
+
TFBertForMaskedLM,
|
| 313 |
+
TFBertForNextSentencePrediction,
|
| 314 |
+
TFBertForSequenceClassification,
|
| 315 |
+
TFBertForMultipleChoice,
|
| 316 |
+
TFBertForTokenClassification,
|
| 317 |
+
TFBertForQuestionAnswering,
|
| 318 |
+
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
from .modeling_tf_gpt2 import (
|
| 322 |
+
TFGPT2PreTrainedModel,
|
| 323 |
+
TFGPT2MainLayer,
|
| 324 |
+
TFGPT2Model,
|
| 325 |
+
TFGPT2LMHeadModel,
|
| 326 |
+
TFGPT2DoubleHeadsModel,
|
| 327 |
+
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
from .modeling_tf_openai import (
|
| 331 |
+
TFOpenAIGPTPreTrainedModel,
|
| 332 |
+
TFOpenAIGPTMainLayer,
|
| 333 |
+
TFOpenAIGPTModel,
|
| 334 |
+
TFOpenAIGPTLMHeadModel,
|
| 335 |
+
TFOpenAIGPTDoubleHeadsModel,
|
| 336 |
+
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
from .modeling_tf_transfo_xl import (
|
| 340 |
+
TFTransfoXLPreTrainedModel,
|
| 341 |
+
TFTransfoXLMainLayer,
|
| 342 |
+
TFTransfoXLModel,
|
| 343 |
+
TFTransfoXLLMHeadModel,
|
| 344 |
+
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
from .modeling_tf_xlnet import (
|
| 348 |
+
TFXLNetPreTrainedModel,
|
| 349 |
+
TFXLNetMainLayer,
|
| 350 |
+
TFXLNetModel,
|
| 351 |
+
TFXLNetLMHeadModel,
|
| 352 |
+
TFXLNetForSequenceClassification,
|
| 353 |
+
TFXLNetForTokenClassification,
|
| 354 |
+
TFXLNetForQuestionAnsweringSimple,
|
| 355 |
+
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
from .modeling_tf_xlm import (
|
| 359 |
+
TFXLMPreTrainedModel,
|
| 360 |
+
TFXLMMainLayer,
|
| 361 |
+
TFXLMModel,
|
| 362 |
+
TFXLMWithLMHeadModel,
|
| 363 |
+
TFXLMForSequenceClassification,
|
| 364 |
+
TFXLMForQuestionAnsweringSimple,
|
| 365 |
+
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
from .modeling_tf_xlm_roberta import (
|
| 369 |
+
TFXLMRobertaForMaskedLM,
|
| 370 |
+
TFXLMRobertaModel,
|
| 371 |
+
TFXLMRobertaForSequenceClassification,
|
| 372 |
+
TFXLMRobertaForTokenClassification,
|
| 373 |
+
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
from .modeling_tf_roberta import (
|
| 377 |
+
TFRobertaPreTrainedModel,
|
| 378 |
+
TFRobertaMainLayer,
|
| 379 |
+
TFRobertaModel,
|
| 380 |
+
TFRobertaForMaskedLM,
|
| 381 |
+
TFRobertaForSequenceClassification,
|
| 382 |
+
TFRobertaForTokenClassification,
|
| 383 |
+
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
from .modeling_tf_camembert import (
|
| 387 |
+
TFCamembertModel,
|
| 388 |
+
TFCamembertForMaskedLM,
|
| 389 |
+
TFCamembertForSequenceClassification,
|
| 390 |
+
TFCamembertForTokenClassification,
|
| 391 |
+
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
from .modeling_tf_distilbert import (
|
| 395 |
+
TFDistilBertPreTrainedModel,
|
| 396 |
+
TFDistilBertMainLayer,
|
| 397 |
+
TFDistilBertModel,
|
| 398 |
+
TFDistilBertForMaskedLM,
|
| 399 |
+
TFDistilBertForSequenceClassification,
|
| 400 |
+
TFDistilBertForTokenClassification,
|
| 401 |
+
TFDistilBertForQuestionAnswering,
|
| 402 |
+
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
from .modeling_tf_ctrl import (
|
| 406 |
+
TFCTRLPreTrainedModel,
|
| 407 |
+
TFCTRLModel,
|
| 408 |
+
TFCTRLLMHeadModel,
|
| 409 |
+
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
from .modeling_tf_albert import (
|
| 413 |
+
TFAlbertPreTrainedModel,
|
| 414 |
+
TFAlbertModel,
|
| 415 |
+
TFAlbertForMaskedLM,
|
| 416 |
+
TFAlbertForSequenceClassification,
|
| 417 |
+
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
from .modeling_tf_t5 import (
|
| 421 |
+
TFT5PreTrainedModel,
|
| 422 |
+
TFT5Model,
|
| 423 |
+
TFT5WithLMHeadModel,
|
| 424 |
+
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# Optimization
|
| 428 |
+
from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
if not is_tf_available() and not is_torch_available():
|
| 432 |
+
logger.warning(
|
| 433 |
+
"Neither PyTorch nor TensorFlow >= 2.0 have been found."
|
| 434 |
+
"Models won't be available and only tokenizers, configuration"
|
| 435 |
+
"and file/data utilities can be used."
|
| 436 |
+
)
|
src/transformers/activations.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def swish(x):
    """Swish activation: ``x * sigmoid(x)`` (Ramachandran et al., 2017)."""
    return torch.sigmoid(x) * x
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _gelu_python(x):
    """Original gelu from the Google BERT repo (exact erf formulation).

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    This is now written in C in torch.nn.functional.
    Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (torch.erf(x / math.sqrt(2.0)) + 1.0)


# Prefer the C implementation from torch.nn.functional when this torch
# version provides one; otherwise fall back to the Python version above.
gelu = getattr(F, "gelu", _gelu_python)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def gelu_new(x):
    """Tanh-approximation gelu currently in the Google BERT repo (identical to OpenAI GPT).

    Also see https://arxiv.org/abs/1606.08415
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (torch.tanh(inner) + 1)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Mapping from activation-name strings (as used in model configs) to callables.
# NOTE: torch.tanh replaces torch.nn.functional.tanh, which is deprecated in
# PyTorch >= 1.2 (same function, same values).
ACT2FN = {
    "relu": F.relu,
    "swish": swish,
    "gelu": gelu,
    "tanh": torch.tanh,
    "gelu_new": gelu_new,
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_activation(activation_string):
    """Look up an activation function by name.

    :param activation_string: one of the keys of ``ACT2FN``.
    :raises KeyError: if the name is unknown.
    """
    # Guard clause: fail with an explicit message listing the valid names.
    if activation_string not in ACT2FN:
        raise KeyError(
            "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
                activation_string, list(ACT2FN.keys())
            )
        )
    return ACT2FN[activation_string]
|
src/transformers/commands/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class BaseTransformersCLICommand(ABC):
    """Common interface implemented by every ``transformers-cli`` subcommand."""

    @staticmethod
    @abstractmethod
    def register_subcommand(parser: ArgumentParser):
        """Register this command's subparser and arguments on *parser*."""
        raise NotImplementedError()

    @abstractmethod
    def run(self):
        """Execute the command."""
        raise NotImplementedError()
|
src/transformers/commands/convert.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser, Namespace
|
| 2 |
+
from logging import getLogger
|
| 3 |
+
|
| 4 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def convert_command_factory(args: Namespace):
    """Factory building a :class:`ConvertCommand` from parsed CLI arguments.

    :return: ConvertCommand
    """
    return ConvertCommand(
        args.model_type,
        args.tf_checkpoint,
        args.pytorch_dump_output,
        args.config,
        args.finetuning_task_name,
    )
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ConvertCommand(BaseTransformersCLICommand):
    """CLI command converting an original (mostly TensorFlow) checkpoint into
    a Transformers PyTorch checkpoint."""

    # Shared error message for converters that require TensorFlow at runtime.
    # (Previously duplicated verbatim in four branches of run().)
    _TF_IMPORT_ERROR_MSG = (
        "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
        "In that case, it requires TensorFlow to be installed. Please see "
        "https://www.tensorflow.org/install/ for installation instructions."
    )

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser(
            "convert",
            help="CLI tool to run convert model from original "
            "author checkpoints to Transformers PyTorch checkpoints.",
        )
        train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
        train_parser.add_argument(
            "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
        )
        # Typo fix: "savd" -> "saved".
        train_parser.add_argument(
            "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch saved model output."
        )
        train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
        train_parser.add_argument(
            "--finetuning_task_name",
            type=str,
            default=None,
            help="Optional fine-tuning task name if the TF model was a finetuned model.",
        )
        train_parser.set_defaults(func=convert_command_factory)

    def __init__(
        self,
        model_type: str,
        tf_checkpoint: str,
        pytorch_dump_output: str,
        config: str,
        finetuning_task_name: str,
        *args
    ):
        self._logger = getLogger("transformers-cli/converting")

        self._logger.info("Loading model {}".format(model_type))
        self._model_type = model_type
        self._tf_checkpoint = tf_checkpoint
        self._pytorch_dump_output = pytorch_dump_output
        self._config = config
        self._finetuning_task_name = finetuning_task_name

    def run(self):
        """Dispatch to the converter matching ``--model_type`` and run it.

        :raises ImportError: if the selected converter needs TensorFlow and it is missing.
        :raises ValueError: if ``--model_type`` is not one of the supported types.
        """
        if self._model_type == "bert":
            try:
                from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
                    convert_tf_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(self._TF_IMPORT_ERROR_MSG)

            convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "gpt":
            from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
                convert_openai_checkpoint_to_pytorch,
            )

            convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(self._TF_IMPORT_ERROR_MSG)

            # The Transfo-XL converter takes either a checkpoint or a dataset
            # file; route --tf_checkpoint based on its name.
            if "ckpt" in self._tf_checkpoint.lower():
                TF_CHECKPOINT = self._tf_checkpoint
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = self._tf_checkpoint
                TF_CHECKPOINT = ""
            convert_transfo_xl_checkpoint_to_pytorch(
                TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
            )
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(self._TF_IMPORT_ERROR_MSG)

            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch,
                )
            except ImportError:
                raise ImportError(self._TF_IMPORT_ERROR_MSG)

            convert_xlnet_checkpoint_to_pytorch(
                self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
            )
        elif self._model_type == "xlm":
            # XLM conversion is PyTorch-to-PyTorch, hence no TensorFlow guard.
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch,
            )

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        else:
            raise ValueError("--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm]")
|
src/transformers/commands/download.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from argparse import ArgumentParser
|
| 2 |
+
|
| 3 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def download_command_factory(args):
    """Factory building a :class:`DownloadCommand` from parsed CLI arguments."""
    return DownloadCommand(args.model, args.cache_dir, args.force)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DownloadCommand(BaseTransformersCLICommand):
    """CLI command that pre-fetches a model and its tokenizer into the local cache."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``download`` subcommand and its arguments to *parser*."""
        sub = parser.add_parser("download")
        sub.add_argument(
            "--cache-dir", type=str, default=None, help="Path to location to store the models"
        )
        sub.add_argument(
            "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
        )
        sub.add_argument("model", type=str, help="Name of the model to download")
        sub.set_defaults(func=download_command_factory)

    def __init__(self, model: str, cache: str, force: bool):
        self._model = model
        self._cache = cache
        self._force = force

    def run(self):
        # Imported lazily so the CLI can start without loading the heavy model stack.
        from transformers import AutoModel, AutoTokenizer

        AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
        AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
src/transformers/commands/env.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import platform
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
|
| 4 |
+
from transformers import __version__ as version
|
| 5 |
+
from transformers import is_tf_available, is_torch_available
|
| 6 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def info_command_factory(_):
    """Instantiate the environment-info command (CLI arguments are ignored)."""
    return EnvironmentCommand()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class EnvironmentCommand(BaseTransformersCLICommand):
    """CLI command printing environment information for GitHub bug reports."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``env`` subcommand to *parser*."""
        env_parser = parser.add_parser("env")
        env_parser.set_defaults(func=info_command_factory)

    @staticmethod
    def _pytorch_info():
        # (version, cuda availability) or placeholders when torch is absent.
        if not is_torch_available():
            return "not installed", "NA"
        import torch

        return torch.__version__, torch.cuda.is_available()

    @staticmethod
    def _tensorflow_info():
        # (version, gpu availability) or placeholders when TF is absent.
        if not is_tf_available():
            return "not installed", "NA"
        import tensorflow as tf

        try:
            # deprecated in v2.1
            gpu = tf.test.is_gpu_available()
        except AttributeError:
            # returns list of devices, convert to bool
            gpu = bool(tf.config.list_physical_devices("GPU"))
        return tf.__version__, gpu

    def run(self):
        """Collect framework/platform versions, print the report, and return it as a dict."""
        pt_version, pt_cuda_available = self._pytorch_info()
        tf_version, tf_cuda_available = self._tensorflow_info()

        info = {
            "`transformers` version": version,
            "Platform": platform.platform(),
            "Python version": platform.python_version(),
            "PyTorch version (GPU?)": "{} ({})".format(pt_version, pt_cuda_available),
            "Tensorflow version (GPU?)": "{} ({})".format(tf_version, tf_cuda_available),
            "Using GPU in script?": "<fill in>",
            "Using distributed or parallel set-up in script?": "<fill in>",
        }

        print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
        print(self.format_dict(info))

        return info

    @staticmethod
    def format_dict(d):
        """Render *d* as a markdown-style bullet list, one ``- key: value`` per line."""
        return "\n".join("- {}: {}".format(prop, val) for prop, val in d.items()) + "\n"
|
src/transformers/commands/run.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
|
| 4 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 5 |
+
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def try_infer_format_from_ext(path: str):
    """Guess the pipeline data format from a file extension.

    A falsy *path* (empty string / None) means stdin/stdout, i.e. "pipe".

    :raises Exception: when the extension matches no supported format.
    """
    if not path:
        return "pipe"

    matched = next((ext for ext in PipelineDataFormat.SUPPORTED_FORMATS if path.endswith(ext)), None)
    if matched is not None:
        return matched

    raise Exception(
        "Unable to determine file format from file extension {}. "
        "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
    )
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_command_factory(args):
    """Build a :class:`RunCommand`: instantiate the pipeline and its data reader."""
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    # Renamed local to avoid shadowing the builtin `format`.
    data_format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        column=args.column if args.column else nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class RunCommand(BaseTransformersCLICommand):
    """CLI command streaming entries from a reader through a pipeline and saving results."""

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``run`` subcommand and its arguments to *parser*."""
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """Feed every reader entry to the pipeline and persist the collected outputs."""
        outputs = []
        for entry in self._reader:
            # Multi-column readers yield dicts of keyword arguments.
            if self._reader.is_multi_columns:
                result = self._nlp(**entry)
            else:
                result = self._nlp(entry)
            if isinstance(result, dict):
                outputs.append(result)
            else:
                outputs += result

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
        else:
            self._reader.save(outputs)
|
src/transformers/commands/serving.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from argparse import ArgumentParser, Namespace
|
| 3 |
+
from typing import Any, List, Optional
|
| 4 |
+
|
| 5 |
+
from transformers import Pipeline
|
| 6 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 7 |
+
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Serving extras (FastAPI/uvicorn stack) are optional. Probe for them once at
# import time and remember the outcome; ServeCommand raises a clear error later.
try:
    from uvicorn import run
    from fastapi import FastAPI, HTTPException, Body
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # Inert stand-ins so the class definitions below still evaluate.
    BaseModel = object

    # Stand-in for fastapi.Body: the method default arguments below call it at
    # class-definition time, so it must at least be callable.
    def Body(*x, **y):
        pass

    _serve_dependencies_installed = False


logger = logging.getLogger("transformers-cli/serving")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.
    :return: ServeCommand
    """
    pipeline_kwargs = {
        "task": args.task,
        "model": args.model if args.model else None,
        "config": args.config,
        "tokenizer": args.tokenizer,
        "device": args.device,
    }
    nlp = pipeline(**pipeline_kwargs)
    return ServeCommand(nlp, args.host, args.port, args.workers)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    # Model configuration attributes as a plain dict (vars(model.config)).
    infos: dict
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    # Surface-form tokens produced by the tokenizer.
    tokens: List[str]
    # Integer token ids; only populated when the request asks for return_ids.
    tokens_ids: Optional[List[int]]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    # Text decoded from the provided token ids.
    text: str
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    # Raw pipeline output; structure depends on the task.
    output: Any
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ServeCommand(BaseTransformersCLICommand):
    """CLI command exposing a pipeline over a small FastAPI REST service."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformer-cli
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        """Wire up the FastAPI app and its routes.

        :raises RuntimeError: if the optional serving dependencies (FastAPI/uvicorn)
            are not installed.
        """
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            # Message fixed: the package is "uvicorn" (was "unicorn"), and the
            # concatenated literals previously lacked a separating space.
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info("Serving model over {}:{}".format(host, port))
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
            )

    def run(self):
        """Start the uvicorn server (blocking)."""
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        """GET / — expose the model configuration."""
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and eventually returns corresponding tokens id:
        - **text_input**: String to tokenize
        - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
        - **skip_special_tokens**: Flag indicating to not try to decode special tokens
        - **cleanup_tokenization_spaces**: Flag indicating to remove all leading/trailing spaces and intermediate ones.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            # NOTE(review): ServeDeTokenizeResult declares no "model" field; the
            # extra kwarg is preserved as-is — confirm against the pydantic version in use.
            return ServeDeTokenizeResult(model="", text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    async def forward(self, inputs=Body(None, embed=True)):
        """
        **inputs**:
        **attention_mask**:
        **tokens_type_ids**:
        """

        # Check we don't have empty string
        if len(inputs) == 0:
            return ServeForwardResult(output=[], attention=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
|
src/transformers/commands/train.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from argparse import ArgumentParser, Namespace
|
| 3 |
+
from logging import getLogger
|
| 4 |
+
|
| 5 |
+
from transformers import SingleSentenceClassificationProcessor as Processor
|
| 6 |
+
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
| 7 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Training needs at least one deep-learning backend; fail fast at import time
# rather than midway through a training run.
if not is_tf_available() and not is_torch_available():
    raise RuntimeError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")

# TF training parameters
USE_XLA = False
USE_AMP = False
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def train_command_factory(args: Namespace):
    """Factory building a :class:`TrainCommand` from parsed CLI arguments.

    :return: TrainCommand
    """
    return TrainCommand(args)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TrainCommand(BaseTransformersCLICommand):
|
| 27 |
+
@staticmethod
|
| 28 |
+
def register_subcommand(parser: ArgumentParser):
|
| 29 |
+
"""
|
| 30 |
+
Register this command to argparse so it's available for the transformer-cli
|
| 31 |
+
:param parser: Root parser to register command-specific arguments
|
| 32 |
+
:return:
|
| 33 |
+
"""
|
| 34 |
+
train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
|
| 35 |
+
|
| 36 |
+
train_parser.add_argument(
|
| 37 |
+
"--train_data",
|
| 38 |
+
type=str,
|
| 39 |
+
required=True,
|
| 40 |
+
help="path to train (and optionally evaluation) dataset as a csv with "
|
| 41 |
+
"tab separated labels and sentences.",
|
| 42 |
+
)
|
| 43 |
+
train_parser.add_argument(
|
| 44 |
+
"--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
|
| 45 |
+
)
|
| 46 |
+
train_parser.add_argument(
|
| 47 |
+
"--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
|
| 48 |
+
)
|
| 49 |
+
train_parser.add_argument(
|
| 50 |
+
"--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
|
| 51 |
+
)
|
| 52 |
+
train_parser.add_argument(
|
| 53 |
+
"--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
|
| 57 |
+
train_parser.add_argument(
|
| 58 |
+
"--validation_split",
|
| 59 |
+
type=float,
|
| 60 |
+
default=0.1,
|
| 61 |
+
help="if validation dataset is not provided, fraction of train dataset " "to use as validation dataset.",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")
|
| 65 |
+
|
| 66 |
+
train_parser.add_argument(
|
| 67 |
+
"--task", type=str, default="text_classification", help="Task to train the model on."
|
| 68 |
+
)
|
| 69 |
+
train_parser.add_argument(
|
| 70 |
+
"--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
|
| 71 |
+
)
|
| 72 |
+
train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
|
| 73 |
+
train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
|
| 74 |
+
train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
|
| 75 |
+
train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
|
| 76 |
+
train_parser.set_defaults(func=train_command_factory)
|
| 77 |
+
|
| 78 |
+
def __init__(self, args: Namespace):
    """Build a training run from parsed CLI arguments.

    Loads the pipeline for the requested task, reads the train (and
    optional validation) dataset from csv, and stores the batch-size /
    optimizer hyper-parameters used later by ``run``.

    Args:
        args: Namespace produced by the ``train`` sub-parser.

    Raises:
        ValueError: if ``args.output`` exists but is not a directory,
            or if ``args.task`` is not a supported task name.
        NotImplementedError: for tasks that are declared but not yet
            implemented (token_classification, question_answering).
    """
    self.logger = getLogger("transformers-cli/training")

    # Prefer TF when it is installed, otherwise fall back to PyTorch.
    self.framework = "tf" if is_tf_available() else "torch"

    os.makedirs(args.output, exist_ok=True)
    # Raise instead of assert: asserts are stripped under ``python -O``,
    # which would let a non-directory path slip through silently.
    if not os.path.isdir(args.output):
        raise ValueError("{} is not a directory".format(args.output))
    self.output = args.output

    self.column_label = args.column_label
    self.column_text = args.column_text
    self.column_id = args.column_id

    self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
    if args.task == "text_classification":
        self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
    elif args.task == "token_classification":
        raise NotImplementedError
    elif args.task == "question_answering":
        raise NotImplementedError
    else:
        # Previously an unknown task fell through silently, leaving
        # ``self.pipeline`` unset and causing an opaque AttributeError
        # later in ``run``. Fail fast with a clear message instead.
        raise ValueError("Unsupported task: {}".format(args.task))

    self.logger.info("Loading dataset from {}".format(args.train_data))
    self.train_dataset = Processor.create_from_csv(
        args.train_data,
        column_label=args.column_label,
        column_text=args.column_text,
        column_id=args.column_id,
        skip_first_row=args.skip_first_row,
    )
    self.valid_dataset = None
    if args.validation_data:
        self.logger.info("Loading validation dataset from {}".format(args.validation_data))
        self.valid_dataset = Processor.create_from_csv(
            args.validation_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )

    self.validation_split = args.validation_split
    self.train_batch_size = args.train_batch_size
    self.valid_batch_size = args.valid_batch_size
    self.learning_rate = args.learning_rate
    self.adam_epsilon = args.adam_epsilon
|
| 123 |
+
|
| 124 |
+
def run(self):
    """Dispatch training to the backend chosen at construction time."""
    # ``self.framework`` was fixed in __init__ to "tf" or "torch".
    if self.framework != "tf":
        return self.run_torch()
    return self.run_tf()
|
| 128 |
+
|
| 129 |
+
def run_torch(self):
    """Train with the PyTorch backend.

    Not implemented yet; ``run_tf`` is the only supported path for now.
    """
    raise NotImplementedError
|
| 131 |
+
|
| 132 |
+
def run_tf(self):
    """Fit the pipeline on the training data, then persist it to disk.

    Validation uses ``self.valid_dataset`` when one was provided;
    otherwise ``validation_split`` carves a fraction out of the train set.
    """
    fit_options = dict(
        validation_data=self.valid_dataset,
        validation_split=self.validation_split,
        learning_rate=self.learning_rate,
        adam_epsilon=self.adam_epsilon,
        train_batch_size=self.train_batch_size,
        valid_batch_size=self.valid_batch_size,
    )
    self.pipeline.fit(self.train_dataset, **fit_options)

    # Save trained pipeline
    self.pipeline.save_pretrained(self.output)
|
src/transformers/commands/user.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from getpass import getpass
|
| 5 |
+
from typing import List, Union
|
| 6 |
+
|
| 7 |
+
from requests.exceptions import HTTPError
|
| 8 |
+
|
| 9 |
+
from transformers.commands import BaseTransformersCLICommand
|
| 10 |
+
from transformers.hf_api import HfApi, HfFolder
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
UPLOAD_MAX_FILES = 15
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class UserCommands(BaseTransformersCLICommand):
    """Registers the huggingface.co user sub-commands (login/whoami/logout/s3/upload)."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # The three account commands share the same registration shape,
        # so register them table-driven.
        for name, help_text, command_cls in (
            ("login", "Log in using the same credentials as on huggingface.co", LoginCommand),
            ("whoami", "Find out which huggingface.co account you are logged in as.", WhoamiCommand),
            ("logout", "Log out", LogoutCommand),
        ):
            sub = parser.add_parser(name, help=help_text)
            # Bind command_cls as a default argument to avoid the
            # late-binding-closure pitfall inside the loop.
            sub.set_defaults(func=lambda args, cls=command_cls: cls(args))
        # s3
        s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
        s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
        ls_parser = s3_subparsers.add_parser("ls")
        ls_parser.set_defaults(func=lambda args: ListObjsCommand(args))
        rm_parser = s3_subparsers.add_parser("rm")
        rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
        rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
        # upload
        upload_parser = parser.add_parser("upload")
        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_parser.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_parser.set_defaults(func=lambda args: UploadCommand(args))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = "\u001b[1m"
    _reset = "\u001b[0m"

    @classmethod
    def bold(cls, s):
        # Wrap ``s`` between the bold-on and reset escape sequences.
        return cls._bold + str(s) + cls._reset
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BaseUserCommand:
    """Shared state for user sub-commands: parsed CLI args and an HfApi client."""

    def __init__(self, args):
        """
        Args:
            args: the argparse Namespace for the invoked sub-command.
        """
        self.args = args
        # NOTE(review): HfApi() appears to take no credentials here —
        # presumably tokens are passed per call; confirm against hf_api.
        self._api = HfApi()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LoginCommand(BaseUserCommand):
    """Interactively log in to huggingface.co and cache the API token on disk."""

    def run(self):
        # ASCII-art banner (runtime output — kept verbatim).
        print(
            """
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|

"""
        )
        username = input("Username: ")
        # getpass hides the password while it is typed.
        password = getpass()
        try:
            # Exchange username/password for an API token.
            token = self._api.login(username, password)
        except HTTPError as e:
            # probably invalid credentials, display error message.
            print(e)
            exit(1)
        # Persist the token so subsequent CLI commands are authenticated.
        HfFolder.save_token(token)
        print("Login successful")
        print("Your token:", token, "\n")
        print("Your token has been saved to", HfFolder.path_token)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class WhoamiCommand(BaseUserCommand):
    """Print the huggingface.co account associated with the stored token."""

    def run(self):
        stored_token = HfFolder.get_token()
        if stored_token is None:
            print("Not logged in")
            exit()
        try:
            # Ask the API who the token belongs to and echo the answer.
            print(self._api.whoami(stored_token))
        except HTTPError as e:
            print(e)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class LogoutCommand(BaseUserCommand):
    """Delete the locally stored token and invalidate it server-side."""

    def run(self):
        stored_token = HfFolder.get_token()
        if stored_token is None:
            print("Not logged in")
            exit()
        # Remove the local copy first, then revoke the token remotely
        # (order preserved from the original implementation).
        HfFolder.delete_token()
        self._api.logout(stored_token)
        print("Successfully logged out.")
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class ListObjsCommand(BaseUserCommand):
    """List the files the logged-in user has uploaded to S3."""

    def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
        """
        Inspired by:
        stackoverflow.com/a/8356620/593036
        stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
        """
        # Each column is as wide as its widest cell (header included).
        widths = [max(len(str(cell)) for cell in column) for column in zip(*rows, headers)]
        fmt = ("{{:{}}} " * len(headers)).format(*widths)
        separator = ["-" * width for width in widths]
        rendered = [fmt.format(*headers), fmt.format(*separator)]
        rendered.extend(fmt.format(*row) for row in rows)
        return "\n".join(rendered)

    def run(self):
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            objs = self._api.list_objs(token)
        except HTTPError as e:
            print(e)
            exit(1)
        if len(objs) == 0:
            print("No shared file yet")
            exit()
        table_rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
        print(self.tabulate(table_rows, headers=["Filename", "LastModified", "ETag", "Size"]))
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class DeleteObjCommand(BaseUserCommand):
    """Delete a single uploaded object from S3 by its filename."""

    def run(self):
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        try:
            self._api.delete_obj(token, filename=self.args.filename)
        except HTTPError as err:
            print(err)
            exit(1)
        else:
            # Only reached when the deletion succeeded.
            print("Done")
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class UploadCommand(BaseUserCommand):
    """Upload a single file or a whole folder (recursively) to the user's S3 space."""

    def walk_dir(self, rel_path):
        """
        Recursively list all files in a folder.

        Returns (absolute filepath, relative filename) pairs.
        """
        entries: List[os.DirEntry] = list(os.scandir(rel_path))
        # Hoist the cwd lookup: it is invariant across the recursion level.
        cwd = os.getcwd()
        collected = [(os.path.join(cwd, entry.path), entry.path) for entry in entries if entry.is_file()]
        for entry in entries:
            if entry.is_dir():
                collected += self.walk_dir(entry.path)
        return collected

    def run(self):
        token = HfFolder.get_token()
        if token is None:
            print("Not logged in")
            exit(1)
        local_path = os.path.abspath(self.args.path)
        if os.path.isdir(local_path):
            if self.args.filename is not None:
                raise ValueError("Cannot specify a filename override when uploading a folder.")
            rel_path = os.path.basename(local_path)
            files = self.walk_dir(rel_path)
        elif os.path.isfile(local_path):
            filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
            files = [(local_path, filename)]
        else:
            raise ValueError("Not a valid file or directory: {}".format(local_path))

        if sys.platform == "win32":
            # S3 object names always use forward slashes.
            files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]

        if len(files) > UPLOAD_MAX_FILES:
            print(
                "About to upload {} files to S3. This is probably wrong. Please filter files before uploading.".format(
                    ANSI.bold(len(files))
                )
            )
            exit(1)

        for filepath, filename in files:
            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

        choice = input("Proceed? [Y/n] ").lower()
        # Empty answer defaults to "yes".
        if choice not in ("", "y", "yes"):
            print("Abort")
            exit()
        print(ANSI.bold("Uploading... This might take a while if files are large"))
        for filepath, filename in files:
            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
            print("Your file now lives at:")
            print(access_url)
|