Spaces:
Configuration error
Configuration error
nengrenjie83
commited on
Commit
•
3ab2ab4
1
Parent(s):
44329ff
Upload 42 files
Browse files- .gitattributes +1 -0
- MedicalGPT-main/.gitignore +129 -0
- MedicalGPT-main/CITATION.cff +9 -0
- MedicalGPT-main/CONTRIBUTING.md +9 -0
- MedicalGPT-main/DISCLAIMER +23 -0
- MedicalGPT-main/LICENSE +201 -0
- MedicalGPT-main/README.md +326 -0
- MedicalGPT-main/README_EN.md +224 -0
- MedicalGPT-main/_config.yml +1 -0
- MedicalGPT-main/build_domain_tokenizer.py +59 -0
- MedicalGPT-main/convert_dataset.py +47 -0
- MedicalGPT-main/data/finetune/medical_sft_1K_format.jsonl +0 -0
- MedicalGPT-main/data/finetune/sharegpt_zh_1K_format.jsonl +0 -0
- MedicalGPT-main/data/pretrain/en_article_tail500.txt +500 -0
- MedicalGPT-main/data/pretrain/fever.txt +0 -0
- MedicalGPT-main/data/pretrain/tianlongbabu.txt +0 -0
- MedicalGPT-main/data/reward/test.json +0 -0
- MedicalGPT-main/data/vocab/baichuan_vocab.txt +0 -0
- MedicalGPT-main/data/vocab/word_freq.txt +0 -0
- MedicalGPT-main/deepspeed_config.json +43 -0
- MedicalGPT-main/docs/GPT_Training.jpg +0 -0
- MedicalGPT-main/docs/demo-screen.gif +0 -0
- MedicalGPT-main/docs/dpo.jpg +0 -0
- MedicalGPT-main/docs/logo.png +3 -0
- MedicalGPT-main/docs/training_details.md +104 -0
- MedicalGPT-main/docs/wechat.jpeg +0 -0
- MedicalGPT-main/dpo_training.py +495 -0
- MedicalGPT-main/gradio_demo.py +215 -0
- MedicalGPT-main/inference.py +225 -0
- MedicalGPT-main/merge_peft_adapter.py +109 -0
- MedicalGPT-main/merge_tokenizers.py +150 -0
- MedicalGPT-main/pretraining.py +678 -0
- MedicalGPT-main/requirements.txt +10 -0
- MedicalGPT-main/reward_modeling.py +643 -0
- MedicalGPT-main/rl_training.py +499 -0
- MedicalGPT-main/run_dpo.sh +29 -0
- MedicalGPT-main/run_pt.sh +42 -0
- MedicalGPT-main/run_rl.sh +24 -0
- MedicalGPT-main/run_rm.sh +39 -0
- MedicalGPT-main/run_sft.sh +40 -0
- MedicalGPT-main/run_training_dpo_pipeline.ipynb +711 -0
- MedicalGPT-main/run_training_pipeline.ipynb +917 -0
- MedicalGPT-main/supervised_finetuning.py +927 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
MedicalGPT-main/docs/logo.png filter=lfs diff=lfs merge=lfs -text
|
MedicalGPT-main/.gitignore
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
# PyInstaller
|
31 |
+
# Usually these files are written by a python script from a template
|
32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
33 |
+
*.manifest
|
34 |
+
*.spec
|
35 |
+
|
36 |
+
# Installer logs
|
37 |
+
pip-log.txt
|
38 |
+
pip-delete-this-directory.txt
|
39 |
+
|
40 |
+
# Unit test / coverage reports
|
41 |
+
htmlcov/
|
42 |
+
.tox/
|
43 |
+
.nox/
|
44 |
+
.coverage
|
45 |
+
.coverage.*
|
46 |
+
.cache
|
47 |
+
nosetests.xml
|
48 |
+
coverage.xml
|
49 |
+
*.cover
|
50 |
+
*.py,cover
|
51 |
+
.hypothesis/
|
52 |
+
.pytest_cache/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
target/
|
76 |
+
|
77 |
+
# Jupyter Notebook
|
78 |
+
.ipynb_checkpoints
|
79 |
+
|
80 |
+
# IPython
|
81 |
+
profile_default/
|
82 |
+
ipython_config.py
|
83 |
+
|
84 |
+
# pyenv
|
85 |
+
.python-version
|
86 |
+
|
87 |
+
# pipenv
|
88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
91 |
+
# install all needed dependencies.
|
92 |
+
#Pipfile.lock
|
93 |
+
|
94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
95 |
+
__pypackages__/
|
96 |
+
|
97 |
+
# Celery stuff
|
98 |
+
celerybeat-schedule
|
99 |
+
celerybeat.pid
|
100 |
+
|
101 |
+
# SageMath parsed files
|
102 |
+
*.sage.py
|
103 |
+
|
104 |
+
# Environments
|
105 |
+
.env
|
106 |
+
.venv
|
107 |
+
env/
|
108 |
+
venv/
|
109 |
+
ENV/
|
110 |
+
env.bak/
|
111 |
+
venv.bak/
|
112 |
+
|
113 |
+
# Spyder project settings
|
114 |
+
.spyderproject
|
115 |
+
.spyproject
|
116 |
+
|
117 |
+
# Rope project settings
|
118 |
+
.ropeproject
|
119 |
+
|
120 |
+
# mkdocs documentation
|
121 |
+
/site
|
122 |
+
|
123 |
+
# mypy
|
124 |
+
.mypy_cache/
|
125 |
+
.dmypy.json
|
126 |
+
dmypy.json
|
127 |
+
|
128 |
+
# Pyre type checker
|
129 |
+
.pyre/
|
MedicalGPT-main/CITATION.cff
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cff-version: 1.2.0
|
2 |
+
message: "If you use this software, please cite it as below."
|
3 |
+
authors:
|
4 |
+
- family-names: "Xu"
|
5 |
+
given-names: "Ming"
|
6 |
+
title: "MedicalGPT: Training Your Own Medical GPT Model with ChatGPT Training Pipeline"
|
7 |
+
url: "https://github.com/shibing624/MedicalGPT"
|
8 |
+
data-released: 2023-06-02
|
9 |
+
version: 0.0.4
|
MedicalGPT-main/CONTRIBUTING.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Contributing
|
2 |
+
|
3 |
+
We are happy to accept your contributions to make this repo better and more awesome! To avoid unnecessary work on either
|
4 |
+
side, please stick to the following process:
|
5 |
+
|
6 |
+
1. Check if there is already an issue for your concern.
|
7 |
+
2. If there is not, open a new one to start a discussion. We hate to close finished PRs!
|
8 |
+
3. If we decide your concern needs code changes, we would be happy to accept a pull request. Please consider the
|
9 |
+
commit guidelines below.
|
MedicalGPT-main/DISCLAIMER
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The software project, data, and models provided by our GitHub project are provided "as is," without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement.
|
2 |
+
|
3 |
+
In no event shall the project owners or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software project, data, or models, even if advised of the possibility of such damage.
|
4 |
+
|
5 |
+
Users of this software project, data, and models are solely responsible for any consequences of their use. The project owners and contributors shall not be held responsible for any subsequent or potential harm caused by the use of this software project, data, or models.
|
6 |
+
|
7 |
+
By using this software project, data, or models, users accept and agree to this disclaimer. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
|
8 |
+
|
9 |
+
It is important to note that this software project, data, and models are still in the research phase and are provided for experimental purposes only. As such, the project owners and contributors do not guarantee the accuracy, completeness, or usefulness of the software project, data, or models.
|
10 |
+
|
11 |
+
Furthermore, due to the experimental nature of this software project, data, and models, it is possible that they may contain or generate inappropriate responses, errors, or inconsistencies. Users should exercise caution when using this software project, data, or models, and should not rely solely on them for any critical or sensitive tasks.
|
12 |
+
|
13 |
+
The project owners and contributors shall not be held responsible for any damages, losses, or liabilities arising from the use of this software project, data, or models, including but not limited to, any inappropriate responses generated by the software project, data, or models.
|
14 |
+
|
15 |
+
By using this software project, data, or models, users acknowledge and accept the experimental nature of the software project, data, and models, and understand the potential risks and limitations associated with their use. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
|
16 |
+
|
17 |
+
The software project, data, and models provided by our GitHub project are intended for research purposes only. They should not be used for any commercial, business, or legal purposes, and should not be relied upon as a substitute for professional advice or judgment.
|
18 |
+
|
19 |
+
Users of this software project, data, and models are strictly prohibited from using them for any commercial purposes, including but not limited to, selling, licensing, or distributing the software project, data, or models to third parties.
|
20 |
+
|
21 |
+
The project owners and contributors shall not be held responsible for any damages, losses, or liabilities arising from the use of this software project, data, or models for any commercial or business purposes.
|
22 |
+
|
23 |
+
By using this software project, data, or models, users agree to use them for research purposes only, and not for any commercial or business purposes. If users do not agree to the terms of this disclaimer, they should not use this software project, data, or models.
|
MedicalGPT-main/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
MedicalGPT-main/README.md
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[**🇨🇳中文**](https://github.com/shibing624/MedicalGPT/blob/main/README.md) | [**🌐English**](https://github.com/shibing624/MedicalGPT/blob/main/README_EN.md) | [**📖文档/Docs**](https://github.com/shibing624/MedicalGPT/wiki) | [**🤖模型/Models**](https://huggingface.co/shibing624)
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<a href="https://github.com/shibing624/MedicalGPT">
|
5 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/logo.png" height="100" alt="Logo">
|
6 |
+
</a>
|
7 |
+
</div>
|
8 |
+
|
9 |
+
-----------------
|
10 |
+
|
11 |
+
# MedicalGPT: Training Medical GPT Model
|
12 |
+
[![HF Models](https://img.shields.io/badge/Hugging%20Face-shibing624-green)](https://huggingface.co/shibing624)
|
13 |
+
[![Github Stars](https://img.shields.io/github/stars/shibing624/MedicalGPT?color=yellow)](https://star-history.com/#shibing624/MedicalGPT&Timeline)
|
14 |
+
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
|
15 |
+
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
|
16 |
+
[![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
|
17 |
+
[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
18 |
+
[![Wechat Group](http://vlog.sfyc.ltd/wechat_everyday/wxgroup_logo.png?imageView2/0/w/60/h/20)](#Contact)
|
19 |
+
|
20 |
+
## 📖 Introduction
|
21 |
+
|
22 |
+
**MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
|
23 |
+
Supervised Finetuning, RLHF(Reward Modeling and Reinforcement Learning) and DPO(Direct Preference Optimization).
|
24 |
+
|
25 |
+
**MedicalGPT** 训练医疗大模型,实现了包括增量预训练、有监督微调、RLHF(奖励建模、强化学习训练)和DPO(直接偏好优化)。
|
26 |
+
|
27 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/dpo.jpg" width="860" />
|
28 |
+
|
29 |
+
- RLHF training pipeline来自Andrej Karpathy的演讲PDF [State of GPT](https://karpathy.ai/stateofgpt.pdf),视频 [Video](https://build.microsoft.com/en-US/sessions/db3f4859-cd30-4445-a0cd-553c3304f8e2)
|
30 |
+
- DPO方法来自论文[Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
|
31 |
+
|
32 |
+
## 🔥 News
|
33 |
+
[2023/08/28] v1.5版本: 新增[DPO(直接偏好优化)](https://arxiv.org/pdf/2305.18290.pdf)方法,DPO通过直接优化语言模型来实现对其行为的精确控制,可以有效学习到人类偏好。详见[Release-v1.5](https://github.com/shibing624/MedicalGPT/releases/tag/1.5.0)
|
34 |
+
|
35 |
+
[2023/08/08] v1.4版本: 发布基于ShareGPT4数据集微调的中英文Vicuna-13B模型[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat),和对应的LoRA模型[shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora),详见[Release-v1.4](https://github.com/shibing624/MedicalGPT/releases/tag/1.4.0)
|
36 |
+
|
37 |
+
[2023/08/02] v1.3版本: 新增LLaMA, LLaMA2, Bloom, ChatGLM, ChatGLM2, Baichuan模型的多轮对话微调训练;新增领域词表扩充功能;新增中文预训练数据集和中文ShareGPT微调训练集,详见[Release-v1.3](https://github.com/shibing624/MedicalGPT/releases/tag/1.3.0)
|
38 |
+
|
39 |
+
[2023/07/13] v1.1版本: 发布中文医疗LLaMA-13B模型[shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged),基于Ziya-LLaMA-13B-v1模型,SFT微调了一版医疗模型,医疗问答效果有提升,发布微调后的完整模型权重,详见[Release-v1.1](https://github.com/shibing624/MedicalGPT/releases/tag/1.1)
|
40 |
+
|
41 |
+
[2023/06/15] v1.0版本: 发布中文医疗LoRA模型[shibing624/ziya-llama-13b-medical-lora](https://huggingface.co/shibing624/ziya-llama-13b-medical-lora),基于Ziya-LLaMA-13B-v1模型,SFT微调了一版医疗模型,医疗问答效果有提升,发布微调后的LoRA权重,详见[Release-v1.0](https://github.com/shibing624/MedicalGPT/releases/tag/1.0.0)
|
42 |
+
|
43 |
+
[2023/06/05] v0.2版本: 以医疗为例,训练领域大模型,实现了四阶段训练:包括二次预训练、有监督微调、奖励建模、强化学习训练。详见[Release-v0.2](https://github.com/shibing624/MedicalGPT/releases/tag/0.2.0)
|
44 |
+
|
45 |
+
|
46 |
+
## 😊 Features
|
47 |
+
|
48 |
+
|
49 |
+
基于ChatGPT Training Pipeline,本项目实现了领域模型--医疗行业语言大模型的训练:
|
50 |
+
|
51 |
+
|
52 |
+
- 第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练GPT模型,以注入领域知识(可选)
|
53 |
+
- 第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图
|
54 |
+
- 第三阶段
|
55 |
+
- RLHF(Reinforcement Learning from Human Feedback)基于人类反馈对语言模型进行强化学习,分为两步:
|
56 |
+
- RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"
|
57 |
+
- RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
|
58 |
+
- [DPO(Direct Preference Optimization)](https://arxiv.org/pdf/2305.18290.pdf)直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
### Release Models
|
63 |
+
|
64 |
+
|
65 |
+
| Model | Base Model | Introduction |
|
66 |
+
|:------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
67 |
+
| [shibing624/ziya-llama-13b-medical-lora](https://huggingface.co/shibing624/ziya-llama-13b-medical-lora) | [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) | 在240万条中英文医疗数据集[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)上SFT微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的LoRA权重(单轮对话) |
|
68 |
+
| [shibing624/ziya-llama-13b-medical-merged](https://huggingface.co/shibing624/ziya-llama-13b-medical-merged) | [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) | 在240万条中英文医疗数据集[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)上SFT微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的完整模型权重(单轮对话) |
|
69 |
+
| [shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora) | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) | 在10万条多语言ShareGPT GPT4多轮对话数据集[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的LoRA权重 |
|
70 |
+
| [shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) | [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) | 在10万条多语言ShareGPT GPT4多轮对话数据集[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的完整模型权重 |
|
71 |
+
|
72 |
+
演示[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat)模型效果:
|
73 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/demo-screen.gif" width="860" />
|
74 |
+
具体case见[Inference Examples](#inference-examples)
|
75 |
+
|
76 |
+
## ▶️ Demo
|
77 |
+
|
78 |
+
|
79 |
+
我们提供了一个简洁的基于gradio的交互式web界面,启动服务后,可通过浏览器访问,输入问题,模型会返回答案。
|
80 |
+
|
81 |
+
启动服务,命令如下:
|
82 |
+
```shell
|
83 |
+
CUDA_VISIBLE_DEVICES=0 python gradio_demo.py --model_type base_model_type --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
|
84 |
+
```
|
85 |
+
|
86 |
+
参数说明:
|
87 |
+
|
88 |
+
- `--model_type {base_model_type}`:预训练模型类型,如llama、bloom、chatglm等
|
89 |
+
- `--base_model {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录,也可使用HF Model Hub模型调用名称
|
90 |
+
- `--lora_model {lora_model}`:LoRA文件所在目录,也可使用HF Model Hub模型调用名称。若lora权重已经合并到预训练模型,则删除--lora_model参数
|
91 |
+
- `--tokenizer_path {tokenizer_path}`:存放对应tokenizer的目录。若不提供此参数,则其默认值与--base_model相同
|
92 |
+
- `--template_name`:模板名称,如`vicuna`、`alpaca`等。若不提供此参数,则其默认值是vicuna
|
93 |
+
- `--only_cpu`: 仅使用CPU进行推理
|
94 |
+
- `--gpus {gpu_ids}`: 指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
|
95 |
+
- `--resize_emb`:是否调整embedding大小,若不调整,则使用预训练模型的embedding大小,默认不调整
|
96 |
+
|
97 |
+
|
98 |
+
## 💾 Install
|
99 |
+
#### Updating the requirements
|
100 |
+
From time to time, the `requirements.txt` changes. To update, use this command:
|
101 |
+
|
102 |
+
```markdown
|
103 |
+
git clone https://github.com/shibing624/MedicalGPT
|
104 |
+
conda activate gpt
|
105 |
+
cd MedicalGPT
|
106 |
+
pip install -r requirements.txt --upgrade
|
107 |
+
```
|
108 |
+
|
109 |
+
## 🚀 Training Pipeline
|
110 |
+
|
111 |
+
Training Stage:
|
112 |
+
|
113 |
+
| Stage | Introduction | Python script | Shell script |
|
114 |
+
|:--------------------------------|:-------------|:--------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------|
|
115 |
+
| Continue Pretraining | 增量预训练 | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |
|
116 |
+
| Supervised Fine-tuning | 有监督微调 | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |
|
117 |
+
| Direct Preference Optimization | 直接偏好优化 | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |
|
118 |
+
| Reward Modeling | 奖励模型建模 | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |
|
119 |
+
| Reinforcement Learning | 强化学习 | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |
|
120 |
+
|
121 |
+
- 提供完整PT+SFT+DPO全阶段串起来训练的pipeline:[run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb),运行完大概需要15分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kMIe3pTec2snQvLBA00Br8ND1_zwy3Gr?usp=sharing)
|
122 |
+
- 提供完整PT+SFT+RLHF全阶段串起来训练的pipeline:[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,运行完大概需要20分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RGkbev8D85gR33HJYxqNdnEThODvGUsS?usp=sharing)
|
123 |
+
- [训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)
|
124 |
+
- [数据集wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%95%B0%E6%8D%AE%E9%9B%86)
|
125 |
+
- [扩充词表wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%89%A9%E5%85%85%E4%B8%AD%E6%96%87%E8%AF%8D%E8%A1%A8)
|
126 |
+
- [FAQ](https://github.com/shibing624/MedicalGPT/wiki/FAQ)
|
127 |
+
|
128 |
+
#### Supported Models
|
129 |
+
|
130 |
+
| 模型名 | 模型大小 | Template |
|
131 |
+
| ------------------------------------------------------- | --------------------------- |---------------|
|
132 |
+
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | vicuna |
|
133 |
+
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | alpaca |
|
134 |
+
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
135 |
+
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | baichuan-chat |
|
136 |
+
| [InternLM](https://github.com/InternLM/InternLM) | 7B | intern |
|
137 |
+
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | chatml |
|
138 |
+
| [XVERSE](https://github.com/xverse-ai/XVERSE-13B) | 13B | xverse |
|
139 |
+
| [ChatGLM](https://github.com/THUDM/ChatGLM-6B) | 6B | chatglm |
|
140 |
+
| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B) | 6B | chatglm2 |
|
141 |
+
|
142 |
+
The following models are tested:
|
143 |
+
|
144 |
+
bloom:
|
145 |
+
- [bigscience/bloomz-560m](https://huggingface.co/bigscience/bloomz-560m)
|
146 |
+
- [bigscience/bloomz-1b7](https://huggingface.co/bigscience/bloomz-1b7)
|
147 |
+
- [bigscience/bloomz-7b1](https://huggingface.co/bigscience/bloomz-7b1)
|
148 |
+
|
149 |
+
llama:
|
150 |
+
- [shibing624/chinese-alpaca-plus-7b-hf](https://huggingface.co/shibing624/chinese-alpaca-plus-7b-hf)
|
151 |
+
- [shibing624/chinese-alpaca-plus-13b-hf](https://huggingface.co/shibing624/chinese-alpaca-plus-13b-hf)
|
152 |
+
- [minlik/chinese-llama-plus-7b-merged](https://huggingface.co/minlik/chinese-llama-plus-7b-merged)
|
153 |
+
- [shibing624/chinese-llama-plus-13b-hf](https://huggingface.co/shibing624/chinese-llama-plus-13b-hf)
|
154 |
+
- [decapoda-research/llama-7b-hf](https://huggingface.co/decapoda-research/llama-7b-hf)
|
155 |
+
- [IDEA-CCNL/Ziya-LLaMA-13B-v1](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1)
|
156 |
+
|
157 |
+
llama2:
|
158 |
+
- [daryl149/llama-2-7b-chat-hf](https://huggingface.co/daryl149/llama-2-7b-chat-hf)
|
159 |
+
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
160 |
+
- [ziqingyang/chinese-alpaca-2-7b](https://huggingface.co/ziqingyang/chinese-alpaca-2-7b)
|
161 |
+
|
162 |
+
chatglm:
|
163 |
+
- [THUDM/chatglm-6b](https://huggingface.co/THUDM/chatglm-6b)
|
164 |
+
- [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b)
|
165 |
+
|
166 |
+
baichuan:
|
167 |
+
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
|
168 |
+
- [baichuan-inc/Baichuan-13B-Base](https://huggingface.co/baichuan-inc/Baichuan-13B-Base)
|
169 |
+
- [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat)
|
170 |
+
|
171 |
+
xverse:
|
172 |
+
- [xverse/XVERSE-13B-Chat](https://huggingface.co/xverse/XVERSE-13B-Chat)
|
173 |
+
|
174 |
+
Qwen:
|
175 |
+
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
|
176 |
+
|
177 |
+
## 💻 Inference
|
178 |
+
训练完成后,现在我们加载训练好的模型,验证模型生成文本的效果。
|
179 |
+
|
180 |
+
```shell
|
181 |
+
CUDA_VISIBLE_DEVICES=0 python inference.py \
|
182 |
+
--model_type base_model_type \
|
183 |
+
--base_model path_to_model_hf_dir \
|
184 |
+
--tokenizer_path path_to_model_hf_dir \
|
185 |
+
--lora_model path_to_lora \
|
186 |
+
--interactive
|
187 |
+
```
|
188 |
+
|
189 |
+
参数说明:
|
190 |
+
|
191 |
+
- `--model_type {base_model_type}`:预训练模型类型,如llama、bloom、chatglm等
|
192 |
+
- `--base_model {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录
|
193 |
+
- `--tokenizer_path {base_model}`:存放HF格式的LLaMA模型权重和配置文件的目录
|
194 |
+
- `--lora_model {lora_model}`:LoRA解压后文件所在目录,也可使用HF Model Hub模型调用名称。如果已经合并了LoRA权重到预训练模型,则可以不提供此参数
|
195 |
+
- `--tokenizer_path {tokenizer_path}`:存放对应tokenizer的目录。若不提供此参数,则其默认值与--base_model相同
|
196 |
+
- `--template_name`:模板名称,如`vicuna`、`alpaca`等。若不提供此参数,则其默认值是vicuna
|
197 |
+
- `--interactive`:以交互方式启动多轮问答,使用流式推理
|
198 |
+
- `--data_file {file_name}`:非交互方式启动下,按行读取file_name中的的内容进行预测
|
199 |
+
- `--predictions_file {file_name}`:非交互式方式下,将预测的结果以json格式写入file_name
|
200 |
+
- `--resize_emb`:是否调整embedding大小,若不调整,则使用预训练模型的embedding大小,默认不调整
|
201 |
+
- `--only_cpu`:仅使用CPU进行推理
|
202 |
+
- `--gpus {gpu_ids}`:指定使用的GPU设备编号,默认为0。如使用多张GPU,以逗号分隔,如0,1,2
|
203 |
+
|
204 |
+
|
205 |
+
#### Inference Examples
|
206 |
+
[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) inference examples:
|
207 |
+
|
208 |
+
<details><summary>Model output case</summary>
|
209 |
+
|
210 |
+
| Human | Assitant |
|
211 |
+
|------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
212 |
+
| 小孩发烧怎么办 | 发烧是身体对感染或其他疾病的自然反应。以下是一些可以帮助缓解发烧症状的建议:<br/>1. 保持身体水分充足:发烧时,身体会失去水分,因此需要保持水分充足。<br/>2. 使用冰袋:将冰袋放在额头或颈部可以帮助降低体温。<br/>3. 使用退烧药:退烧药可以帮助降低体温。<br/>4. 保持身体温度:保持身体温度可以帮助降低体温。<br/>5. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。<br/>6. 使用退烧药:退烧药可以帮助降低体温。<br/>7. 保持身体温度:保持身体温度可以帮助降低体温。<br/>8. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。<br/>9. 使用退烧药:退烧药可以帮助降低体温。<br/>10. 保持身体温度:保持身体温度可以帮助降低体温。 |
|
213 |
+
| 一岁宝宝发烧能吃啥药? | 1岁宝宝发烧时,可以考虑以下几种药物:<br/>1. 退烧药:如布洛芬(Ibuprofen)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助降低体温,缓解发烧症状。<br/>2. 抗生素:如果发烧是由细菌感染引起的,可以考虑使用抗生素。但请务必在医生的建议下使用。<br/>3. 抗病毒药物:如果发烧是由病毒感染引起的,可以考虑使用抗病毒药物。但请务必在医生的建议下使用。<br/>4. 解热镇痛药:如阿司匹林(Aspirin)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助缓解发烧症状。<br/>5. 抗组胺药:如苯海拉明(Diphenhydramine)或非索非那定(Naproxen)。这些药物可以帮助缓解发烧症状。<br/>请务必在医生的建议下使用这些药物,并遵循药物说明书上的用量和用法。同时,保持宝宝充足的水分摄入,并确保宝宝有足够的休息。 |
|
214 |
+
| 肛门病变可能是什么疾病的症状 | 肛门病变可能由多种原因引起,包括感染、炎症、肿瘤或其他疾病。一些常见的肛门病变包括:<br/>1. 肛裂:这是一种肛门皮肤撕裂,通常由便秘或其他压力因素引起。<br/>2. 肛瘘:这是一种肛门感染,导致肛门皮肤和直肠之间形成瘘管。<br/>3. 肛门疣:这是一种肛门皮肤上形成的疣状病变,通常由人类乳头状瘤病毒(HPV)引起。<br/>4. 肛门脓肿:这是一种肛门感染,导致肛门皮肤下形成脓肿。<br/>5. 肛门肿瘤:这是一种肛门皮肤或直肠内形成的肿瘤。<br/>6. 肛门溃疡:这是一种肛门皮肤上形成的溃疡,通常由感染或炎症引起。<br/>7. 肛门瘙痒:这是一种肛门皮肤瘙痒的症状,可能由感染、炎症或其他疾病引起。<br/>如果您有肛门病变的症状,请咨询医疗保健专业人员以获得适当的诊断和治疗。 |
|
215 |
+
| 医生,您好白带增多泛黄怎么回事? | 您好,白带增多和发黄可能是由多种原因引起的,包括感染、荷尔蒙变化、阴道干燥或其他妇科问题。以下是一些可能的原因:<br/>1. 感染:白带增多和发黄可能是由阴道或尿道感染引起的。感染可能由细菌、病毒或真菌引起。<br/>2. 荷尔蒙变化:荷尔蒙变化,如月经周期或更年期,可能会导致白带增多和发黄。<br/>3. 阴道干燥:阴道干燥可能会导致白带增多和发黄。<br/>4. 其他妇科问题:其他妇科问题,如子宫颈炎、子宫内膜炎或子宫内膜异位症,也可能导致白带增多和发黄。<br/>如果您担心自己的白带,建议您咨询妇科医生进行检查和诊断。医生可以根据您的症状和检查结果确定原因并提供适当的治疗。 |
|
216 |
+
| 药流后有巧克力白带是没流干净吗 | 药流后出现巧克力白带并不一定意味着没有流干净。巧克力白带是由于子宫内膜脱落和出血引起的,通常在药流后的一段时间内会出现。<br/><br/>药流后出现巧克力白带的原因可能包括:<br/><br/>1. 子宫内膜脱落:药流后,子宫内膜会脱落,导致出血。这种出血可能表现为巧克力白带。<br/>2. 感染:药流后,子宫内膜和宫颈口可能存在创伤,容易受到细菌感染。感染可能导致巧克力白带。<br/>3. 激素水平变化���药流后,激素水平可能发生变化,导致巧克力白带。<br/><br/>如果巧克力白带持续时间较长,或者伴有其他症状,如腹痛、发热、分泌物异味等,建议及时就诊,进行相关检查,如妇科检查、B超等,以确定是否存在感染或其他问题。在医生的指导下进行相应的治疗。 |
|
217 |
+
| 维胺酯维E乳膏能治理什么疾病 | 维胺酯维E乳膏是一种外用药物,主要用于治疗痤疮、脂溢性皮炎等皮肤病。它可以帮助减少皮肤油脂分泌,减轻炎症,促进皮肤修复。<br/><br/>需要注意的是,维胺酯维E乳膏并非万能药,对于某些皮肤病可能并不适用。在使用前,建议咨询医生或药师,了解药物的适应症、禁忌症、用法用量等信息。同时,使用药物时应遵循医嘱,避免自行滥用。 |
|
218 |
+
| 痤疮是什么原因导致的? | 痤疮,也称为粉刺,是由皮肤毛孔堵塞引起的。堵塞毛孔的原因有很多,包括:<br/>1. 皮脂腺过度活跃:皮脂腺产生过多油脂,导致毛孔堵塞。<br/>2. 毛孔堵塞:皮肤表面的死皮细胞、油脂和细菌堵塞毛孔。<br/>3. 炎症:痤疮通常伴随着炎症,这可能是由于皮肤感染或免疫系统对痤疮的反应引起的。<br/>4. 遗传因素:某些人可能更容易患痤疮,这可能是由于遗传因素引起的。<br/>5. 荷尔蒙变化:痤疮在青春期和月经期间更为常见,这可能是由于荷尔蒙变化引起的。<br/>6. 生活方式因素:某些生活方式因素,如压力、不良饮食和缺乏运动,也可能导致痤疮。<br/>7. 化妆品和护肤品:某些化妆品和护肤品可能导致痤疮,特别是那些含有刺激性成分的产品。<br/>8. 药物:某些药物,如口服避孕药,也可能导致痤疮。<br/>要预防痤疮,可以采取一些措施,如保持皮肤清洁、使用温和的护肤品、避免过度使用化妆品和护肤品、保持健康的生活方式和饮食习惯、避免压力和荷尔蒙变化。如果痤疮严重,可能需要寻求皮肤科医生的帮助。 |
|
219 |
+
| 脸上长青春痘怎么办,能吃啥药,需要注意啥 | 青春痘,又称痤疮,是一种常见的皮肤病。青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。以下是一些建议:<br/><br/>1. 饮食:<br/>* 避免高糖、高脂、高油的食物,如快餐、甜点、油炸食品等。<br/>* 多吃新鲜蔬菜、水果、全谷类食物,以增加纤维摄入,促进肠道蠕动。<br/>* 避免摄入过多的咖啡因、酒精、辛辣食物等刺激性食物。<br/><br/>2. 生活习惯:<br/>* 保持良好的作息时间,保证充足的睡眠。<br/>* 避免长时间使用电脑、手机等电子产品,以免辐射对皮肤造成刺激。<br/>* 保持皮肤清洁,每天早晚使用温和的洁面产品清洁皮肤。<br/>* 避免挤压青春痘,以免感染和留下疤痕。<br/><br/>3. 药物:<br/>* 外用药物:可以使用含有水杨酸、苯甲酸、壬二酸等成分的护肤品,帮助去除皮肤表面的角质,减少痘痘的形成。<br/>* 口服药物:如果青春痘严重,可以考虑口服抗生素、维生素B族、锌等药物,但需在医生指导下使用。<br/><br/>4. 注意:<br/>* 不要随意使用化妆品、护肤品,以免刺激皮肤。<br/>* 不要过度清洁皮肤,以免破坏皮肤屏障。<br/>* 保持心情舒畅,避免焦虑、抑郁等负面情绪。<br/><br/>总之,青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。在治疗过程中,要保持耐心和信心,遵循医生的建议,逐步改善皮肤状况。 |
|
220 |
+
| 由于输精管堵塞会不会导致男性不育?,医生您好,我同事这几天不开心,我问她怎么了,她说她老公输精管堵塞,所以一直没孩子,不是不想要,请问输精管堵塞会不会导致男性不育? | 输精管堵塞可能会导致男性不育。输精管是连接睾丸���尿道的一条管道,负责将精子从睾丸输送到尿道。如果输精管堵塞,精子就无法通过输精管到达尿道,从而导致不育。<br/><br/>输精管堵塞的原因有很多,包括先天性发育异常、炎症、肿瘤、外伤等。输精管堵塞的治疗方法包括手术疏通、药物治疗等。如果输精管堵塞严重,可能需要进行手术疏通。<br/><br/>需要注意的是,输精管堵塞并不一定导致不育,因为还有其他途径可以将精子输送到女性体内。例如,精子可以通过附睾、前列腺等组织到达女性体内。因此,如果输精管堵塞,建议及时就医,寻求专业医生的帮助。 |
|
221 |
+
| 心肌梗塞能治疗吗? | 心肌梗死是一种严重的疾病,如果及时治疗,可以挽救生命并防止长期残疾。治疗心肌梗死的关键是尽早识别和治疗。<br/><br/>心肌梗死的治疗通常包括以下步骤:<br/><br/>1. 立即就医:如果出现胸痛、呼吸急促或晕厥等症状,立即就医。<br/><br/>2. 药物治疗:医生可能会开处方药物,如硝酸甘油或阿司匹林,以帮助缓解症状并防止进一步的损伤。<br/><br/>3. 溶栓治疗:如果患者在到达医院前出现心肌梗死,医生可能会进行溶栓治疗,通过注射溶解血栓的药物来帮助恢复血流。<br/><br/>4. 手术治疗:如果药物治疗或溶栓治疗无效,医生可能会建议进行手术,如冠状动脉搭桥手术或心脏支架植入术。<br/><br/>5. 康复和预防:心肌梗死的治疗通常包括康复计划,以帮助患者恢复身体功能并预防未来的心脏事件。此外,患者还应遵循医生的建议,以降低患心脏病的风险。<br/><br/>需要注意的是,心肌梗死的治疗取决于患者的具体情况和症状。因此,如果您或您认识的人出现心肌梗死症状,请立即就医。 |
|
222 |
+
| 怎么让自己精力充沛,列5点建议 | 1. 保持充足的睡眠:睡眠对于保持精力充沛至关重要。确保每晚获得7-9小时的高质量睡眠。<br/>2. 均衡饮食:均衡饮食对于保持能量水平至关重要。确保摄入足够的蛋白质、碳水化合物和健康脂肪。<br/>3. 定期锻炼:定期锻炼可以帮助提高能量水平。尝试每天进行至少30分钟的有氧运动,如散步、跑步或骑自行车。<br/>4. 管理压力:压力会消耗能量,因此管理压力对于保持精力充沛至关重要。尝试找到放松和减压的方法,如冥想、瑜伽或深呼吸练习。<br/>5. 保持水分:脱水会降低能量水平。确保每天喝足够的水,并根据需要补充水分。 |
|
223 |
+
| 写个春天的诗 | 春天,<br/>是大地苏醒的季节,<br/>是万物复苏的季节,<br/>是生命绽放的季节。<br/>春天,<br/>是阳光明媚的季节,<br/>是鸟语花香的季节,<br/>是绿意盎然的季节,<br/>是希望无限的季节。<br/>春天,<br/>是人们欢笑的季节,<br/>是孩子们嬉戏的季节,<br/>是恋人们相拥的季节,<br/>是人们追求梦想的季节。 |
|
224 |
+
| 南极洲为啥没有北极熊? | 南极洲没有北极熊,因为北极熊生活在北极地区,而南极洲位于南半球,气候和环境与北极地区有很大不同。南极洲的气候非常寒冷,常年被冰雪覆盖,而北极熊则生活在北极地区,那里有丰富的海洋生物和陆地哺乳动物。此外,南极洲的生态系统也与北极地区不同,没有北极熊所需的生存条件。 |
|
225 |
+
|
226 |
+
</details>
|
227 |
+
|
228 |
+
## 📚 Dataset
|
229 |
+
### 医疗数据集
|
230 |
+
|
231 |
+
- 240万条中文医疗数据集(包括预训练、指令微调和奖励数据集):[shibing624/medical](https://huggingface.co/datasets/shibing624/medical)
|
232 |
+
- 22万条中文医疗对话数据集(华佗项目):[FreedomIntelligence/HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
|
233 |
+
|
234 |
+
### 通用数据集
|
235 |
+
|
236 |
+
#### Pretraining datasets
|
237 |
+
- 16GB中英文无监督、平行语料[Linly-AI/Chinese-pretraining-dataset](https://huggingface.co/datasets/Linly-AI/Chinese-pretraining-dataset)
|
238 |
+
- 524MB中文维基百科语料[wikipedia-cn-20230720-filtered](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
239 |
+
#### SFT datasets
|
240 |
+
- 10万条多语言ShareGPT GPT4多轮对话数据集:[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) [本项目支持格式]
|
241 |
+
- 9万条英文ShareGPT多轮对话数集:[anon8231489123/ShareGPT_Vicuna_unfiltered](https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered) [本项目支持格式]
|
242 |
+
- 50万条中文ChatGPT指令Belle数据集:[BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
243 |
+
- 100万条中文ChatGPT指令Belle数据集:[BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
244 |
+
- 5万条英文ChatGPT指令Alpaca数据集:[50k English Stanford Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca#data-release)
|
245 |
+
- 2万条中文ChatGPT指令Alpaca数据集:[shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
|
246 |
+
- 69万条中文指令Guanaco数据集(Belle50万条+Guanaco19万条):[Chinese-Vicuna/guanaco_belle_merge_v1.0](https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0)
|
247 |
+
- 5万条英文ChatGPT多轮对话数据集:[RyokoAI/ShareGPT52K](https://huggingface.co/datasets/RyokoAI/ShareGPT52K)
|
248 |
+
- 80万条中文ChatGPT多轮对话数据集:[BelleGroup/multiturn_chat_0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
249 |
+
- 116万条中文ChatGPT多轮对话数据集:[fnlp/moss-002-sft-data](https://huggingface.co/datasets/fnlp/moss-002-sft-data)
|
250 |
+
- 3.8万条中文ShareGPT多轮对话数据集:[FreedomIntelligence/ShareGPT-CN](https://huggingface.co/datasets/FreedomIntelligence/ShareGPT-CN)
|
251 |
+
|
252 |
+
#### Reward Model datasets
|
253 |
+
- 原版的oasst1数据集:[OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
254 |
+
- 2万条多语言oasst1的reward数据集:[tasksource/oasst1_pairwise_rlhf_reward](https://huggingface.co/datasets/tasksource/oasst1_pairwise_rlhf_reward)[本项目支持格式]
|
255 |
+
- 11万条英文hh-rlhf的reward数据集:[Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)
|
256 |
+
- 9万条英文reward数据集(来自Anthropic's Helpful Harmless dataset):[Dahoas/static-hh](https://huggingface.co/datasets/Dahoas/static-hh)
|
257 |
+
- 7万条英文reward数据集(来源同上):[Dahoas/rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
258 |
+
- 7万条繁体中文的reward数据集(翻译自rm-static)[liswei/rm-static-m2m100-zh](https://huggingface.co/datasets/liswei/rm-static-m2m100-zh)
|
259 |
+
- 7万条英文Reward数据集:[yitingxie/rlhf-reward-datasets](https://huggingface.co/datasets/yitingxie/rlhf-reward-datasets)
|
260 |
+
- 3千条中文知乎问答偏好数据集:[liyucheng/zhihu_rlhf_3k](https://huggingface.co/datasets/liyucheng/zhihu_rlhf_3k)
|
261 |
+
|
262 |
+
## ✅ Todo
|
263 |
+
|
264 |
+
1. [x] add multi-round dialogue data fine-tuning method
|
265 |
+
2. [x] add reward model fine-tuning
|
266 |
+
3. [x] add rl fine-tuning
|
267 |
+
4. [x] add medical reward dataset
|
268 |
+
5. [x] add llama in8/int4 training
|
269 |
+
6. [x] add all training and predict demo in colab
|
270 |
+
7. [x] add dpo training
|
271 |
+
|
272 |
+
## ☎️ Contact
|
273 |
+
|
274 |
+
- Issue(建议)
|
275 |
+
:[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
276 |
+
- 邮件我:xuming: xuming624@qq.com
|
277 |
+
- 微信我: 加我*微信号:xuming624, 备注:姓名-公司名-NLP* 进NLP交流群。
|
278 |
+
|
279 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/wechat.jpeg" width="200" />
|
280 |
+
|
281 |
+
## ⚠️ 局限性、使用限制与免责声明
|
282 |
+
|
283 |
+
基于当前数据和基础模型训练得到的SFT模型,在效果上仍存在以下问题:
|
284 |
+
|
285 |
+
1. 在涉及事实性的指令上可能会产生违背事实的错误回答。
|
286 |
+
|
287 |
+
2. 对于具备危害性的指令无法很好的鉴别,由此会产生危害性言论。
|
288 |
+
|
289 |
+
3. 在一些涉及推理、代码、多轮对话等场景下模型的能力仍有待提高。
|
290 |
+
|
291 |
+
基于以上模型局限性,我们要求开发者仅将我���开源的模型权重及后续用此项目生成的衍生物用于研究目的,不得用于商业,以及其他会对社会带来危害的用途。
|
292 |
+
|
293 |
+
本项目仅可应用于研究目的,项目开发者不承担任何因使用本项目(包含但不限于数据、模型、代码等)导致的危害或损失。详细请参考[免责声明](https://github.com/shibing624/MedicalGPT/blob/main/DISCLAIMER)。
|
294 |
+
|
295 |
+
项目代码的授权协议为 [The Apache License 2.0](/LICENSE),代码可免费用做商业用途,模型权重和数据只能用于研究目的。请在产品说明中附加MedicalGPT的链接和授权协议。
|
296 |
+
|
297 |
+
|
298 |
+
## 😇 Citation
|
299 |
+
|
300 |
+
如果你在研究中使用了MedicalGPT,请按如下格式引用:
|
301 |
+
|
302 |
+
```latex
|
303 |
+
@misc{MedicalGPT,
|
304 |
+
title={MedicalGPT: Training Medical GPT Model},
|
305 |
+
author={Ming Xu},
|
306 |
+
year={2023},
|
307 |
+
howpublished={\url{https://github.com/shibing624/MedicalGPT}},
|
308 |
+
}
|
309 |
+
```
|
310 |
+
|
311 |
+
## 😍 Contribute
|
312 |
+
|
313 |
+
项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:
|
314 |
+
|
315 |
+
- 在`tests`添加相应的单元测试
|
316 |
+
- 使用`python -m pytest`来运行所有单元测试,确保所有单测都是通过的
|
317 |
+
|
318 |
+
之后即可提交PR。
|
319 |
+
|
320 |
+
## 💕 Acknowledgements
|
321 |
+
|
322 |
+
- [Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
|
323 |
+
- [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
|
324 |
+
- [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
325 |
+
|
326 |
+
Thanks for their great work!
|
MedicalGPT-main/README_EN.md
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[**🇨🇳中文**](https://github.com/shibing624/MedicalGPT/blob/main/README.md) | [**🌐English**](https://github.com/shibing624/MedicalGPT/blob/main/README_EN.md) | [**📖文档/Docs**](https://github.com/shibing624/MedicalGPT/wiki) | [**🤖模型/Models**](https://huggingface.co/shibing624)
|
2 |
+
|
3 |
+
<div align="center">
|
4 |
+
<a href="https://github.com/shibing624/MedicalGPT">
|
5 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/logo.png" width="120" alt="Logo">
|
6 |
+
</a>
|
7 |
+
</div>
|
8 |
+
|
9 |
+
-----------------
|
10 |
+
|
11 |
+
# MedicalGPT: Training Medical GPT Model
|
12 |
+
[![HF Models](https://img.shields.io/badge/Hugging%20Face-shibing624-green)](https://huggingface.co/shibing624)
|
13 |
+
[![Github Stars](https://img.shields.io/github/stars/shibing624/MedicalGPT?color=yellow)](https://star-history.com/#shibing624/MedicalGPT&Timeline)
|
14 |
+
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
|
15 |
+
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
|
16 |
+
[![python_version](https://img.shields.io/badge/Python-3.8%2B-green.svg)](requirements.txt)
|
17 |
+
[![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
18 |
+
[![Wechat Group](http://vlog.sfyc.ltd/wechat_everyday/wxgroup_logo.png?imageView2/0/w/60/h/20)](#Contact)
|
19 |
+
|
20 |
+
## 📖 Introduction
|
21 |
+
|
22 |
+
**MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
|
23 |
+
Supervised Finetuning, Reward Modeling and Reinforcement Learning.
|
24 |
+
|
25 |
+
|
26 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/GPT_Training.jpg" width="860" />
|
27 |
+
|
28 |
+
Training MedicalGPT model:
|
29 |
+
|
30 |
+
- Stage 1:PT(Continue PreTraining), Pre-training the LLaMA model on massive domain document data to inject domain knowledge
|
31 |
+
- Stage 2: SFT (Supervised Fine-tuning) has supervised fine-tuning, constructs instruction fine-tuning data sets, and performs instruction fine-tuning on the basis of pre-trained models to align instruction intentions
|
32 |
+
- Stage 3: RM (Reward Model) reward model modeling, constructing a human preference ranking data set, training the reward model to align human preferences, mainly the "HHH" principle, specifically "helpful, honest, harmless"
|
33 |
+
- Stage 4: RL (Reinforcement Learning) is based on human feedback reinforcement learning (RLHF), using the reward model to train the SFT model, and the generation model uses rewards or penalties to update its strategy in order to generate higher quality, more in line with human preferences text
|
34 |
+
|
35 |
+
## ▶️ Demo
|
36 |
+
|
37 |
+
- Hugging Face Demo: doing
|
38 |
+
|
39 |
+
We provide a simple Gradio-based interactive web interface. After the service is started, it can be accessed through a browser, enter a question, and the model will return an answer. The command is as follows:
|
40 |
+
```shell
|
41 |
+
python scripts/gradio_demo.py --base_model path_to_llama_hf_dir --lora_model path_to_lora_dir
|
42 |
+
```
|
43 |
+
|
44 |
+
Parameter Description:
|
45 |
+
|
46 |
+
- `--base_model {base_model}`: directory to store LLaMA model weights and configuration files in HF format, or use the HF Model Hub model call name
|
47 |
+
- `--lora_model {lora_model}`: The directory where the LoRA file is located, and the name of the HF Model Hub model can also be used. If the lora weights have been merged into the pre-trained model, delete the --lora_model parameter
|
48 |
+
- `--tokenizer_path {tokenizer_path}`: Store the directory corresponding to the tokenizer. If this parameter is not provided, its default value is the same as --lora_model; if the --lora_model parameter is not provided, its default value is the same as --base_model
|
49 |
+
- `--use_cpu`: use only CPU for inference
|
50 |
+
- `--gpus {gpu_ids}`: Specifies the number of GPU devices used, the default is 0. If using multiple GPUs, separate them with commas, such as 0,1,2
|
51 |
+
|
52 |
+
|
53 |
+
## 🚀 Training Pipeline
|
54 |
+
|
55 |
+
### Stage 1: Continue Pretraining
|
56 |
+
|
57 |
+
Based on the llama-7b model, use medical encyclopedia data to continue pre-training, and expect to inject medical knowledge into the pre-training model to obtain the llama-7b-pt model. This step is optional
|
58 |
+
|
59 |
+
|
60 |
+
```shell
|
61 |
+
cd scripts
|
62 |
+
sh run_pt.sh
|
63 |
+
```
|
64 |
+
|
65 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
66 |
+
|
67 |
+
### Stage 2: Supervised FineTuning
|
68 |
+
Based on the llama-7b-pt model, the llama-7b-sft model is obtained by using medical question-and-answer data for supervised fine-tuning. This step is required
|
69 |
+
|
70 |
+
Supervised fine-tuning of the base llama-7b-pt model to create llama-7b-sft
|
71 |
+
|
72 |
+
```shell
|
73 |
+
cd scripts
|
74 |
+
sh run_sft.sh
|
75 |
+
```
|
76 |
+
|
77 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
78 |
+
|
79 |
+
### Stage 3: Reward Modeling
|
80 |
+
RM(Reward Model): reward model modeling
|
81 |
+
|
82 |
+
In principle, we can directly use human annotations to fine-tune the model with RLHF.
|
83 |
+
|
84 |
+
However, this will require us to send some samples to humans to be scored after each round of optimization. This is expensive and slow due to the large number of training samples required for convergence and the limited speed at which humans can read and annotate them.
|
85 |
+
A better strategy than direct feedback is to train a reward model RM on the human annotated set before entering the RL loop. The purpose of the reward model is to simulate human scoring of text.
|
86 |
+
|
87 |
+
The best practice for building a reward model is to rank the prediction results, that is, for each prompt (input text) corresponding to two results (yk, yj), the model predicts which score the human annotation is higher.
|
88 |
+
The RM model is trained by manually marking the scoring results of the SFT model. The purpose is to replace manual scoring. It is essentially a regression model used to align human preferences, mainly based on the "HHH" principle, specifically "helpful, honest, harmless".
|
89 |
+
|
90 |
+
|
91 |
+
Based on the llama-7b-sft model, the reward preference model is trained using medical question and answer preference data, and the llama-7b-reward model is obtained after training. This step is required
|
92 |
+
|
93 |
+
Reward modeling using dialog pairs from the reward dataset using the llama-7b-sft to create llama-7b-reward:
|
94 |
+
|
95 |
+
```shell
|
96 |
+
cd scripts
|
97 |
+
sh run_rm.sh
|
98 |
+
```
|
99 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
100 |
+
|
101 |
+
### Stage 4: Reinforcement Learning
|
102 |
+
The purpose of the RL (Reinforcement Learning) model is to maximize the output of the reward model. Based on the above steps, we have a fine-tuned language model (llama-7b-sft) and reward model (llama-7b-reward).
|
103 |
+
The RL loop is ready to execute.
|
104 |
+
|
105 |
+
This process is roughly divided into three steps:
|
106 |
+
|
107 |
+
1. Enter prompt, the model generates a reply
|
108 |
+
2. Use a reward model to score responses
|
109 |
+
3. Based on the score, a round of reinforcement learning for policy optimization (PPO)
|
110 |
+
|
111 |
+
<img src=https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/trl_loop.png height=400 />
|
112 |
+
|
113 |
+
Reinforcement Learning fine-tuning of llama-7b-sft with the llama-7b-reward reward model to create llama-7b-rl
|
114 |
+
|
115 |
+
```shell
|
116 |
+
cd scripts
|
117 |
+
sh run_rl.sh
|
118 |
+
```
|
119 |
+
[Training Detail wiki](https://github.com/shibing624/MedicalGPT/wiki/Training-Details)
|
120 |
+
|
121 |
+
## 🔥 Inference
|
122 |
+
After the training is complete, now we load the trained model to verify the effect of the model generating text.
|
123 |
+
|
124 |
+
```shell
|
125 |
+
python scripts/inference.py \
|
126 |
+
--base_model path_to_llama_hf_dir \
|
127 |
+
--lora_model path_to_lora \
|
128 |
+
--with_prompt \
|
129 |
+
--interactive
|
130 |
+
```
|
131 |
+
|
132 |
+
Parameter Description:
|
133 |
+
|
134 |
+
- `--base_model {base_model}`: Directory to store LLaMA model weights and configuration files in HF format
|
135 |
+
- `--lora_model {lora_model}`: The directory where the LoRA file is decompressed, and the name of the HF Model Hub model can also be used. If you have incorporated LoRA weights into the pre-trained model, you can not provide this parameter
|
136 |
+
- `--tokenizer_path {tokenizer_path}`: Store the directory corresponding to the tokenizer. If this parameter is not provided, its default value is the same as --lora_model; if the --lora_model parameter is not provided, its default value is the same as --base_model
|
137 |
+
- `--with_prompt`: Whether to merge the input with the prompt template. Be sure to enable this option if loading an Alpaca model!
|
138 |
+
- `--interactive`: start interactively for multiple single rounds of question and answer
|
139 |
+
- `--data_file {file_name}`: Start in non-interactive mode, read the contents of file_name line by line for prediction
|
140 |
+
- `--predictions_file {file_name}`: In non-interactive mode, write the predicted results to file_name in json format
|
141 |
+
- `--use_cpu`: use only CPU for inference
|
142 |
+
- `--gpus {gpu_ids}`: Specifies the number of GPU devices used, the default is 0. If using multiple GPUs, separate them with commas, such as 0,1,2
|
143 |
+
|
144 |
+
|
145 |
+
#### Inference Examples
|
146 |
+
|
147 |
+
<details>
|
148 |
+
| Human | Assistant |
|
149 |
+
| ----- | ------ |
|
150 |
+
| How to make yourself energetic, list 5 suggestions | Here are some suggestions to keep you energetic: <br/>1. Get enough sleep. Make sure to get enough sleep every night. <br/>2. Exercise more. Doing moderate aerobic exercise, such as jogging and swimming, can help improve your energy and stamina. 3. Eat healthy food. Choose protein-rich foods such as chicken, fish, and eggs, as well as fresh fruits and vegetables. 4. Reduce caffeine intake. Try to avoid drinking caffeinated beverages or eating caffeinated foods during the day. <br/>5. Relax. Try meditation, yoga, or other forms of relaxation to reduce stress and anxiety. |
|
151 |
+
</details>
|
152 |
+
<br/>
|
153 |
+
|
154 |
+
|
155 |
+
## 📚 Dataset
|
156 |
+
|
157 |
+
- 2.4 million Chinese medical datasets (including pre-training, instruction fine-tuning and reward datasets): [shibing624/medical](https://huggingface.co/datasets/shibing624/medical)
|
158 |
+
|
159 |
+
**Attach links to some general datasets and medical datasets**
|
160 |
+
|
161 |
+
- Belle dataset of 500,000 Chinese ChatGPT commands: [BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
162 |
+
- Belle dataset of 1 million Chinese ChatGPT commands: [BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
163 |
+
- Alpaca dataset of 50,000 English ChatGPT commands: [50k English Stanford Alpaca dataset](https://github.com/tatsu-lab/stanford_alpaca#data-release)
|
164 |
+
- Alpaca dataset of 20,000 Chinese GPT-4 instructions: [shibing624/alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh)
|
165 |
+
- Guanaco dataset with 690,000 Chinese instructions (500,000 Belle + 190,000 Guanaco): [Chinese-Vicuna/guanaco_belle_merge_v1.0](https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0)
|
166 |
+
- 220,000 Chinese medical dialogue datasets (HuatuoGPT project): [FreedomIntelligence/HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1)
|
167 |
+
|
168 |
+
## ✅ Todo
|
169 |
+
|
170 |
+
1. [ ] Added multi-round dialogue data fine-tuning method
|
171 |
+
2. [x] add reward model finetuning
|
172 |
+
3. [x] add rl finetuning
|
173 |
+
4. [x] add medical reward dataset
|
174 |
+
5. [x] add llama in8/int4 training
|
175 |
+
6. [ ] add all training and predict demo in colab
|
176 |
+
## ☎️ Contact
|
177 |
+
|
178 |
+
- Issue (suggestion)
|
179 |
+
: [![GitHub issues](https://img.shields.io/github/issues/shibing624/MedicalGPT.svg)](https://github.com/shibing624/MedicalGPT/issues)
|
180 |
+
- Email me: xuming: xuming624@qq.com
|
181 |
+
- WeChat Me: Add me* WeChat ID: xuming624, Remarks: Name-Company Name-NLP* Enter the NLP exchange group.
|
182 |
+
|
183 |
+
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/wechat.jpeg" width="200" />
|
184 |
+
|
185 |
+
## ⚠️ Limitations, Restrictions of Use and Disclaimer
|
186 |
+
|
187 |
+
The SFT model trained based on the current data and the basic model still has the following problems in terms of effect:
|
188 |
+
|
189 |
+
1. Wrong answers that contradict the facts may be generated on the factual instructions.
|
190 |
+
2. Unable to identify harmful instructions well, resulting in harmful speech.
|
191 |
+
3. The ability of the model still needs to be improved in some scenarios involving reasoning, code, and multiple rounds of dialogue.
|
192 |
+
|
193 |
+
Based on the limitations of the above models, we require developers to only use our open source model weights and subsequent derivatives generated by this project for research purposes, and not for commercial use, and other purposes that will cause harm to society.
|
194 |
+
This project can only be used for research purposes, and the project developer is not responsible for any harm or loss caused by the use of this project (including but not limited to data, models, codes, etc.). For details, please refer to [Disclaimer](https://github.com/shibing624/MedicalGPT/blob/main/DISCLAIMER).
|
195 |
+
The license agreement for the project code is [The Apache License 2.0](/LICENSE), the code is free for commercial use, and the model weights and data can only be used for research purposes. Please attach MedicalGPT's link and license agreement in the product description.
|
196 |
+
|
197 |
+
## 😇 Citation
|
198 |
+
|
199 |
+
If you used MedicalGPT in your research, please cite as follows:
|
200 |
+
|
201 |
+
```latex
|
202 |
+
@misc{MedicalGPT,
|
203 |
+
title={MedicalGPT: Training Medical GPT Model},
|
204 |
+
author={Ming Xu},
|
205 |
+
year={2023},
|
206 |
+
howpublished={\url{https://github.com/shibing624/MedicalGPT}},
|
207 |
+
}
|
208 |
+
```
|
209 |
+
|
210 |
+
## 😍 Contribute
|
211 |
+
|
212 |
+
The project code is still very rough. If you have improved the code, you are welcome to submit it back to this project. Before submitting, please pay attention to the following two points:
|
213 |
+
|
214 |
+
- Add corresponding unit tests in `tests`
|
215 |
+
- Use `python -m pytest` to run all unit tests to ensure that all unit tests are passed
|
216 |
+
|
217 |
+
Then you can submit a PR.
|
218 |
+
|
219 |
+
## 💕 Acknowledgements
|
220 |
+
|
221 |
+
- [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
|
222 |
+
- [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
223 |
+
|
224 |
+
Thanks for their great work!
|
MedicalGPT-main/_config.yml
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
theme: jekyll-theme-cayman
|
MedicalGPT-main/build_domain_tokenizer.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description: Build chinese tokenizer from corpus txt
|
5 |
+
|
6 |
+
# train sentencepiece model from `corpus.txt` and makes `m.model` and `m.vocab`
|
7 |
+
# `m.vocab` is just a reference. not used in the segmentation.
|
8 |
+
# spm.SentencePieceTrainer.train('--input=data/pretrain/tianlongbabu.txt --model_prefix=m --vocab_size=20000')
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
import sentencepiece as spm
|
13 |
+
|
14 |
+
|
15 |
+
def main():
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument('--in_file', default='data/pretrain/fever.txt', type=str)
|
18 |
+
parser.add_argument('--domain_sp_model_name', default='domain_sp', type=str)
|
19 |
+
parser.add_argument('--max_sentence_length', default=16384, type=int)
|
20 |
+
parser.add_argument('--pad_id', default=3, type=int)
|
21 |
+
parser.add_argument('--vocab_size', default=2236, type=int)
|
22 |
+
parser.add_argument('--model_type', default="BPE", type=str)
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
print(args)
|
26 |
+
|
27 |
+
spm.SentencePieceTrainer.train(
|
28 |
+
input=args.in_file,
|
29 |
+
model_prefix=args.domain_sp_model_name,
|
30 |
+
shuffle_input_sentence=False,
|
31 |
+
train_extremely_large_corpus=True,
|
32 |
+
max_sentence_length=args.max_sentence_length,
|
33 |
+
pad_id=args.pad_id,
|
34 |
+
model_type=args.model_type,
|
35 |
+
vocab_size=args.vocab_size,
|
36 |
+
split_digits=True,
|
37 |
+
split_by_unicode_script=True,
|
38 |
+
byte_fallback=True,
|
39 |
+
allow_whitespace_only_pieces=True,
|
40 |
+
remove_extra_whitespaces=False,
|
41 |
+
normalization_rule_name="nfkc",
|
42 |
+
)
|
43 |
+
|
44 |
+
# makes segmenter instance and loads the model file (m.model)
|
45 |
+
sp = spm.SentencePieceProcessor()
|
46 |
+
model_file = args.domain_sp_model_name + '.model'
|
47 |
+
sp.load(model_file)
|
48 |
+
|
49 |
+
# encode: text => id
|
50 |
+
print(sp.encode_as_pieces('潜伏性感染又称潜在性感染。慕容复来到河边,this is a test'))
|
51 |
+
print(sp.encode_as_ids('this is a test'))
|
52 |
+
|
53 |
+
# decode: id => text
|
54 |
+
print(sp.decode_pieces(['▁This', '▁is', '▁a', '▁t', 'est']))
|
55 |
+
# print(sp.decode_ids([209, 31, 9, 375, 586]))
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == '__main__':
|
59 |
+
main()
|
MedicalGPT-main/convert_dataset.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Convert alpaca dataset into sharegpt format.
|
3 |
+
|
4 |
+
Usage: python convert_alpaca.py --in_file alpaca_data.json --out_file alpaca_data_sharegpt.json
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
from datasets import load_dataset
|
10 |
+
|
11 |
+
if __name__ == "__main__":
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument("--in_file", type=str)
|
14 |
+
parser.add_argument("--out_file", type=str)
|
15 |
+
parser.add_argument("--data_type", type=str, default='alpaca')
|
16 |
+
args = parser.parse_args()
|
17 |
+
print(args)
|
18 |
+
data_files = {"train": args.in_file}
|
19 |
+
raw_datasets = load_dataset('json', data_files=data_files)
|
20 |
+
ds = raw_datasets['train']
|
21 |
+
|
22 |
+
|
23 |
+
def process_alpaca(examples):
|
24 |
+
convs = []
|
25 |
+
for instruction, inp, output in zip(examples['instruction'], examples['input'], examples['output']):
|
26 |
+
if len(inp.strip()) > 1:
|
27 |
+
instruction = instruction + '\n\n' + inp
|
28 |
+
q = instruction
|
29 |
+
a = output
|
30 |
+
convs.append([
|
31 |
+
{"from": "human", "value": q},
|
32 |
+
{"from": "gpt", "value": a}
|
33 |
+
])
|
34 |
+
return {"conversations": convs}
|
35 |
+
|
36 |
+
|
37 |
+
if args.data_type in ['alpaca']:
|
38 |
+
ds = ds.map(process_alpaca, batched=True, remove_columns=ds.column_names, desc="Running process")
|
39 |
+
else:
|
40 |
+
# Other sharegpt dataset, need rename to conversations and remove unused columns
|
41 |
+
if "items" in ds.column_names:
|
42 |
+
ds = ds.rename(columns={"items": "conversations"})
|
43 |
+
columns_to_remove = ds.column_names.copy()
|
44 |
+
columns_to_remove.remove('conversations')
|
45 |
+
ds = ds.remove_columns(columns_to_remove)
|
46 |
+
|
47 |
+
ds.to_json(f"{args.out_file}", lines=True, force_ascii=False)
|
MedicalGPT-main/data/finetune/medical_sft_1K_format.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/data/finetune/sharegpt_zh_1K_format.jsonl
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/data/pretrain/en_article_tail500.txt
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
contract to work in specified mines and mills. There seemed to be no
|
2 |
+
limit to the factories, forges, refineries, and railways that could be
|
3 |
+
built, to the multitudes that could be employed in conquering a
|
4 |
+
continent. As for the future, that was in the hands of Providence!
|
5 |
+
|
6 |
+
=Business Theories of Politics.=--As the statesmen of Hamilton's school
|
7 |
+
and the planters of Calhoun's had their theories of government and
|
8 |
+
politics, so the leaders in business enterprise had theirs. It was
|
9 |
+
simple and easily stated. "It is the duty of the government," they
|
10 |
+
urged, "to protect American industry against foreign competition by
|
11 |
+
means of high tariffs on imported goods, to aid railways by generous
|
12 |
+
grants of land, to sell mineral and timber lands at low prices to
|
13 |
+
energetic men ready to develop them, and then to leave the rest to the
|
14 |
+
initiative and drive of individuals and companies." All government
|
15 |
+
interference with the management, prices, rates, charges, and conduct of
|
16 |
+
private business they held to be either wholly pernicious or intolerably
|
17 |
+
impertinent. Judging from their speeches and writings, they conceived
|
18 |
+
the nation as a great collection of individuals, companies, and labor
|
19 |
+
unions all struggling for profits or high wages and held together by a
|
20 |
+
government whose principal duty was to keep the peace among them and
|
21 |
+
protect industry against the foreign manufacturer. Such was the
|
22 |
+
political theory of business during the generation that followed the
|
23 |
+
Civil War.
|
24 |
+
|
25 |
+
|
26 |
+
THE SUPREMACY OF THE REPUBLICAN PARTY (1861-85)
|
27 |
+
|
28 |
+
=Business Men and Republican Policies.=--Most of the leaders in industry
|
29 |
+
gravitated to the Republican ranks. They worked in the North and the
|
30 |
+
Republican party was essentially Northern. It was moreover--at least so
|
31 |
+
far as the majority of its members were concerned--committed to
|
32 |
+
protective tariffs, a sound monetary and banking system, the promotion
|
33 |
+
of railways and industry by land grants, and the development of internal
|
34 |
+
improvements. It was furthermore generous in its immigration policy. It
|
35 |
+
proclaimed America to be an asylum for the oppressed of all countries
|
36 |
+
and flung wide the doors for immigrants eager to fill the factories, man
|
37 |
+
the mines, and settle upon Western lands. In a word the Republicans
|
38 |
+
stood for all those specific measures which favored the enlargement and
|
39 |
+
prosperity of business. At the same time they resisted government
|
40 |
+
interference with private enterprise. They did not regulate railway
|
41 |
+
rates, prosecute trusts for forming combinations, or prevent railway
|
42 |
+
companies from giving lower rates to some shippers than to others. To
|
43 |
+
sum it up, the political theories of the Republican party for three
|
44 |
+
decades after the Civil War were the theories of American
|
45 |
+
business--prosperous and profitable industries for the owners and "the
|
46 |
+
full dinner pail" for the workmen. Naturally a large portion of those
|
47 |
+
who flourished under its policies gave their support to it, voted for
|
48 |
+
its candidates, and subscribed to its campaign funds.
|
49 |
+
|
50 |
+
=Sources of Republican Strength in the North.=--The Republican party was
|
51 |
+
in fact a political organization of singular power. It originated in a
|
52 |
+
wave of moral enthusiasm, having attracted to itself, if not the
|
53 |
+
abolitionists, certainly all those idealists, like James Russell Lowell
|
54 |
+
and George William Curtis, who had opposed slavery when opposition was
|
55 |
+
neither safe nor popular. To moral principles it added practical
|
56 |
+
considerations. Business men had confidence in it. Workingmen, who
|
57 |
+
longed for the independence of the farmer, owed to its indulgent land
|
58 |
+
policy the opportunity of securing free homesteads in the West. The
|
59 |
+
immigrant, landing penniless on these shores, as a result of the same
|
60 |
+
beneficent system, often found himself in a little while with an estate
|
61 |
+
as large as many a baronial domain in the Old World. Under a Republican
|
62 |
+
administration, the union had been saved. To it the veterans of the war
|
63 |
+
could turn with confidence for those rewards of service which the
|
64 |
+
government could bestow: pensions surpassing in liberality anything that
|
65 |
+
the world had ever seen. Under a Republican administration also the
|
66 |
+
great debt had been created in the defense of the union, and to the
|
67 |
+
Republican party every investor in government bonds could look for the
|
68 |
+
full and honorable discharge of the interest and principal. The spoils
|
69 |
+
system, inaugurated by Jacksonian Democracy, in turn placed all the
|
70 |
+
federal offices in Republican hands, furnishing an army of party workers
|
71 |
+
to be counted on for loyal service in every campaign.
|
72 |
+
|
73 |
+
Of all these things Republican leaders made full and vigorous use,
|
74 |
+
sometimes ascribing to the party, in accordance with ancient political
|
75 |
+
usage, merits and achievements not wholly its own. Particularly was this
|
76 |
+
true in the case of saving the union. "When in the economy of
|
77 |
+
Providence, this land was to be purged of human slavery ... the
|
78 |
+
Republican party came into power," ran a declaration in one platform.
|
79 |
+
"The Republican party suppressed a gigantic rebellion, emancipated four
|
80 |
+
million slaves, decreed the equal citizenship of all, and established
|
81 |
+
universal suffrage," ran another. As for the aid rendered by the
|
82 |
+
millions of Northern Democrats who stood by the union and the tens of
|
83 |
+
thousands of them who actually fought in the union army, the Republicans
|
84 |
+
in their zeal were inclined to be oblivious. They repeatedly charged the
|
85 |
+
Democratic party "with being the same in character and spirit as when it
|
86 |
+
sympathized with treason."
|
87 |
+
|
88 |
+
=Republican Control of the South.=--To the strength enjoyed in the
|
89 |
+
North, the Republicans for a long time added the advantages that came
|
90 |
+
from control over the former Confederate states where the newly
|
91 |
+
enfranchised negroes, under white leadership, gave a grateful support to
|
92 |
+
the party responsible for their freedom. In this branch of politics,
|
93 |
+
motives were so mixed that no historian can hope to appraise them all at
|
94 |
+
their proper values. On the one side of the ledger must be set the
|
95 |
+
vigorous efforts of the honest and sincere friends of the freedmen to
|
96 |
+
win for them complete civil and political equality, wiping out not only
|
97 |
+
slavery but all its badges of misery and servitude. On the same side
|
98 |
+
must be placed the labor of those who had valiantly fought in forum and
|
99 |
+
field to save the union and who regarded continued Republican supremacy
|
100 |
+
after the war as absolutely necessary to prevent the former leaders in
|
101 |
+
secession from coming back to power. At the same time there were
|
102 |
+
undoubtedly some men of the baser sort who looked on politics as a game
|
103 |
+
and who made use of "carpet-bagging" in the South to win the spoils that
|
104 |
+
might result from it. At all events, both by laws and presidential acts,
|
105 |
+
the Republicans for many years kept a keen eye upon the maintenance of
|
106 |
+
their dominion in the South. Their declaration that neither the law nor
|
107 |
+
its administration should admit any discrimination in respect of
|
108 |
+
citizens by reason of race, color, or previous condition of servitude
|
109 |
+
appealed to idealists and brought results in elections. Even South
|
110 |
+
Carolina, where reposed the ashes of John C. Calhoun, went Republican in
|
111 |
+
1872 by a vote of three to one!
|
112 |
+
|
113 |
+
Republican control was made easy by the force bills described in a
|
114 |
+
previous chapter--measures which vested the supervision of elections in
|
115 |
+
federal officers appointed by Republican Presidents. These drastic
|
116 |
+
measures, departing from American tradition, the Republican authors
|
117 |
+
urged, were necessary to safeguard the purity of the ballot, not merely
|
118 |
+
in the South where the timid freedman might readily be frightened from
|
119 |
+
using it; but also in the North, particularly in New York City, where it
|
120 |
+
was claimed that fraud was regularly practiced by Democratic leaders.
|
121 |
+
|
122 |
+
The Democrats, on their side, indignantly denied the charges, replying
|
123 |
+
that the force bills were nothing but devices created by the Republicans
|
124 |
+
for the purpose of securing their continued rule through systematic
|
125 |
+
interference with elections. Even the measures of reconstruction were
|
126 |
+
deemed by Democratic leaders as thinly veiled schemes to establish
|
127 |
+
Republican power throughout the country. "Nor is there the slightest
|
128 |
+
doubt," exclaimed Samuel J. Tilden, spokesman of the Democrats in New
|
129 |
+
York and candidate for President in 1876, "that the paramount object and
|
130 |
+
motive of the Republican party is by these means to secure itself
|
131 |
+
against a reaction of opinion adverse to it in our great populous
|
132 |
+
Northern commonwealths.... When the Republican party resolved to
|
133 |
+
establish negro supremacy in the ten states in order to gain to itself
|
134 |
+
the representation of those states in Congress, it had to begin by
|
135 |
+
governing the people of those states by the sword.... The next was the
|
136 |
+
creation of new electoral bodies for those ten states, in which, by
|
137 |
+
exclusions, by disfranchisements and proscriptions, by control over
|
138 |
+
registration, by applying test oaths ... by intimidation and by every
|
139 |
+
form of influence, three million negroes are made to predominate over
|
140 |
+
four and a half million whites."
|
141 |
+
|
142 |
+
=The War as a Campaign Issue.=--Even the repeal of force bills could not
|
143 |
+
allay the sectional feelings engendered by the war. The Republicans
|
144 |
+
could not forgive the men who had so recently been in arms against the
|
145 |
+
union and insisted on calling them "traitors" and "rebels." The
|
146 |
+
Southerners, smarting under the reconstruction acts, could regard the
|
147 |
+
Republicans only as political oppressors. The passions of the war had
|
148 |
+
been too strong; the distress too deep to be soon forgotten. The
|
149 |
+
generation that went through it all remembered it all. For twenty
|
150 |
+
years, the Republicans, in their speeches and platforms, made "a
|
151 |
+
straight appeal to the patriotism of the Northern voters." They
|
152 |
+
maintained that their party, which had saved the union and emancipated
|
153 |
+
the slaves, was alone worthy of protecting the union and uplifting the
|
154 |
+
freedmen.
|
155 |
+
|
156 |
+
Though the Democrats, especially in the North, resented this policy and
|
157 |
+
dubbed it with the expressive but inelegant phrase, "waving the bloody
|
158 |
+
shirt," the Republicans refused to surrender a slogan which made such a
|
159 |
+
ready popular appeal. As late as 1884, a leader expressed the hope that
|
160 |
+
they might "wring one more President from the bloody shirt." They
|
161 |
+
refused to let the country forget that the Democratic candidate, Grover
|
162 |
+
Cleveland, had escaped military service by hiring a substitute; and they
|
163 |
+
made political capital out of the fact that he had "insulted the
|
164 |
+
veterans of the Grand Army of the Republic" by going fishing on
|
165 |
+
Decoration Day.
|
166 |
+
|
167 |
+
=Three Republican Presidents.=--Fortified by all these elements of
|
168 |
+
strength, the Republicans held the presidency from 1869 to 1885. The
|
169 |
+
three Presidents elected in this period, Grant, Hayes, and Garfield, had
|
170 |
+
certain striking characteristics in common. They were all of origin
|
171 |
+
humble enough to please the most exacting Jacksonian Democrat. They had
|
172 |
+
been generals in the union army. Grant, next to Lincoln, was regarded as
|
173 |
+
the savior of the Constitution. Hayes and Garfield, though lesser lights
|
174 |
+
in the military firmament, had honorable records duly appreciated by
|
175 |
+
veterans of the war, now thoroughly organized into the Grand Army of the
|
176 |
+
Republic. It is true that Grant was not a politician and had never voted
|
177 |
+
the Republican ticket; but this was readily overlooked. Hayes and
|
178 |
+
Garfield on the other hand were loyal party men. The former had served
|
179 |
+
in Congress and for three terms as governor of his state. The latter had
|
180 |
+
long been a member of the House of Representatives and was Senator-elect
|
181 |
+
when he received the nomination for President.
|
182 |
+
|
183 |
+
All of them possessed, moreover, another important asset, which was not
|
184 |
+
forgotten by the astute managers who led in selecting candidates. All
|
185 |
+
of them were from Ohio--though Grant had been in Illinois when the
|
186 |
+
summons to military duties came--and Ohio was a strategic state. It lay
|
187 |
+
between the manufacturing East and the agrarian country to the West.
|
188 |
+
Having growing industries and wool to sell it benefited from the
|
189 |
+
protective tariff. Yet being mainly agricultural still, it was not
|
190 |
+
|
191 |
+
without sympathy for the farmers who showed low tariff or free trade
|
192 |
+
tendencies. Whatever share the East had in shaping laws and framing
|
193 |
+
policies, it was clear that the West was to have the candidates. This
|
194 |
+
division in privileges--not uncommon in political management--was always
|
195 |
+
accompanied by a judicious selection of the candidate for Vice
|
196 |
+
President. With Garfield, for example, was associated a prominent New
|
197 |
+
York politician, Chester A. Arthur, who, as fate decreed, was destined
|
198 |
+
to more than three years' service as chief magistrate, on the
|
199 |
+
assassination of his superior in office.
|
200 |
+
|
201 |
+
=The Disputed Election of 1876.=--While taking note of the long years of
|
202 |
+
Republican supremacy, it must be recorded that grave doubts exist in the
|
203 |
+
minds of many historians as to whether one of the three Presidents,
|
204 |
+
Hayes, was actually the victor in 1876 or not. His Democratic opponent,
|
205 |
+
Samuel J. Tilden, received a popular plurality of a quarter of a million
|
206 |
+
and had a plausible claim to a majority of the electoral vote. At all
|
207 |
+
events, four states sent in double returns, one set for Tilden and
|
208 |
+
another for Hayes; and a deadlock ensued. Both parties vehemently
|
209 |
+
claimed the election and the passions ran so high that sober men did not
|
210 |
+
shrink from speaking of civil war again. Fortunately, in the end, the
|
211 |
+
counsels of peace prevailed. Congress provided for an electoral
|
212 |
+
commission of fifteen men to review the contested returns. The
|
213 |
+
Democrats, inspired by Tilden's moderation, accepted the judgment in
|
214 |
+
favor of Hayes even though they were not convinced that he was really
|
215 |
+
entitled to the office.
|
216 |
+
|
217 |
+
|
218 |
+
THE GROWTH OF OPPOSITION TO REPUBLICAN RULE
|
219 |
+
|
220 |
+
=Abuses in American Political Life.=--During their long tenure of
|
221 |
+
office, the Republicans could not escape the inevitable consequences of
|
222 |
+
power; that is, evil practices and corrupt conduct on the part of some
|
223 |
+
who found shelter within the party. For that matter neither did the
|
224 |
+
Democrats manage to avoid such difficulties in those states and cities
|
225 |
+
where they had the majority. In New York City, for instance, the local
|
226 |
+
Democratic organization, known as Tammany Hall, passed under the sway of
|
227 |
+
a group of politicians headed by "Boss" Tweed. He plundered the city
|
228 |
+
treasury until public-spirited citizens, supported by Samuel J. Tilden,
|
229 |
+
the Democratic leader of the state, rose in revolt, drove the ringleader
|
230 |
+
from power, and sent him to jail. In Philadelphia, the local Republican
|
231 |
+
bosses were guilty of offenses as odious as those committed by New York
|
232 |
+
politicians. Indeed, the decade that followed the Civil War was marred
|
233 |
+
by so many scandals in public life that one acute editor was moved to
|
234 |
+
inquire: "Are not all the great communities of the Western World growing
|
235 |
+
more corrupt as they grow in wealth?"
|
236 |
+
|
237 |
+
In the sphere of national politics, where the opportunities were
|
238 |
+
greater, betrayals of public trust were even more flagrant. One
|
239 |
+
revelation after another showed officers, high and low, possessed with
|
240 |
+
the spirit of peculation. Members of Congress, it was found, accepted
|
241 |
+
railway stock in exchange for votes in favor of land grants and other
|
242 |
+
concessions to the companies. In the administration as well as the
|
243 |
+
legislature the disease was rife. Revenue officers permitted whisky
|
244 |
+
distillers to evade their taxes and received heavy bribes in return. A
|
245 |
+
probe into the post-office department revealed the malodorous "star
|
246 |
+
route frauds"--the deliberate overpayment of certain mail carriers whose
|
247 |
+
lines were indicated in the official record by asterisks or stars. Even
|
248 |
+
cabinet officers did not escape suspicion, for the trail of the serpent
|
249 |
+
led straight to the door of one of them.
|
250 |
+
|
251 |
+
In the lower ranges of official life, the spoils system became more
|
252 |
+
virulent as the number of federal employees increased. The holders of
|
253 |
+
offices and the seekers after them constituted a veritable political
|
254 |
+
army. They crowded into Republican councils, for the Republicans, being
|
255 |
+
in power, could alone dispense federal favors. They filled positions in
|
256 |
+
the party ranging from the lowest township committee to the national
|
257 |
+
convention. They helped to nominate candidates and draft platforms and
|
258 |
+
elbowed to one side the busy citizen, not conversant with party
|
259 |
+
intrigues, who could only give an occasional day to political matters.
|
260 |
+
Even the Civil Service Act of 1883, wrung from a reluctant Congress two
|
261 |
+
years after the assassination of Garfield, made little change for a long
|
262 |
+
time. It took away from the spoilsmen a few thousand government
|
263 |
+
positions, but it formed no check on the practice of rewarding party
|
264 |
+
workers from the public treasury.
|
265 |
+
|
266 |
+
On viewing this state of affairs, many a distinguished citizen became
|
267 |
+
profoundly discouraged. James Russell Lowell, for example, thought he
|
268 |
+
saw a steady decline in public morals. In 1865, hearing of Lee's
|
269 |
+
surrender, he had exclaimed: "There is something magnificent in having a
|
270 |
+
country to love!" Ten years later, when asked to write an ode for the
|
271 |
+
centennial at Philadelphia in 1876, he could think only of a biting
|
272 |
+
satire on the nation:
|
273 |
+
|
274 |
+
"Show your state legislatures; show your Rings;
|
275 |
+
And challenge Europe to produce such things
|
276 |
+
As high officials sitting half in sight
|
277 |
+
To share the plunder and fix things right.
|
278 |
+
If that don't fetch her, why, you need only
|
279 |
+
To show your latest style in martyrs,--Tweed:
|
280 |
+
She'll find it hard to hide her spiteful tears
|
281 |
+
At such advance in one poor hundred years."
|
282 |
+
|
283 |
+
When his critics condemned him for this "attack upon his native land,"
|
284 |
+
Lowell replied in sadness: "These fellows have no notion of what love of
|
285 |
+
country means. It was in my very blood and bones. If I am not an
|
286 |
+
American who ever was?... What fills me with doubt and dismay is the
|
287 |
+
degradation of the moral tone. Is it or is it not a result of democracy?
|
288 |
+
Is ours a 'government of the people, by the people, for the people,' or
|
289 |
+
a Kakistocracy [a government of the worst], rather for the benefit of
|
290 |
+
knaves at the cost of fools?"
|
291 |
+
|
292 |
+
=The Reform Movement in Republican Ranks.=--The sentiments expressed by
|
293 |
+
Lowell, himself a Republican and for a time American ambassador to
|
294 |
+
England, were shared by many men in his party. Very soon after the close
|
295 |
+
of the Civil War some of them began to protest vigorously against the
|
296 |
+
policies and conduct of their leaders. In 1872, the dissenters, calling
|
297 |
+
themselves Liberal Republicans, broke away altogether, nominated a
|
298 |
+
candidate of their own, Horace Greeley, and put forward a platform
|
299 |
+
indicting the Republican President fiercely enough to please the most
|
300 |
+
uncompromising Democrat. They accused Grant of using "the powers and
|
301 |
+
opportunities of his high office for the promotion of personal ends."
|
302 |
+
They charged him with retaining "notoriously corrupt and unworthy men in
|
303 |
+
places of power and responsibility." They alleged that the Republican
|
304 |
+
party kept "alive the passions and resentments of the late civil war to
|
305 |
+
use them for their own advantages," and employed the "public service of
|
306 |
+
the government as a machinery of corruption and personal influence."
|
307 |
+
|
308 |
+
It was not apparent, however, from the ensuing election that any
|
309 |
+
considerable number of Republicans accepted the views of the Liberals.
|
310 |
+
Greeley, though indorsed by the Democrats, was utterly routed and died
|
311 |
+
of a broken heart. The lesson of his discomfiture seemed to be that
|
312 |
+
independent action was futile. So, at least, it was regarded by most men
|
313 |
+
of the rising generation like Henry Cabot Lodge, of Massachusetts, and
|
314 |
+
Theodore Roosevelt, of New York. Profiting by the experience of Greeley
|
315 |
+
they insisted in season and out that reformers who desired to rid the
|
316 |
+
party of abuses should remain loyal to it and do their work "on the
|
317 |
+
inside."
|
318 |
+
|
319 |
+
=The Mugwumps and Cleveland Democracy in 1884.=--Though aided by
|
320 |
+
Republican dissensions, the Democrats were slow in making headway
|
321 |
+
against the political current. They were deprived of the energetic and
|
322 |
+
capable leadership once afforded by the planters, like Calhoun, Davis,
|
323 |
+
and Toombs; they were saddled by their opponents with responsibility for
|
324 |
+
secession; and they were stripped of the support of the prostrate
|
325 |
+
South. Not until the last Southern state was restored to the union, not
|
326 |
+
until a general amnesty was wrung from Congress, not until white
|
327 |
+
supremacy was established at the polls, and the last federal soldier
|
328 |
+
withdrawn from Southern capitals did they succeed in capturing the
|
329 |
+
presidency.
|
330 |
+
|
331 |
+
The opportune moment for them came in 1884 when a number of
|
332 |
+
circumstances favored their aspirations. The Republicans, leaving the
|
333 |
+
Ohio Valley in their search for a candidate, nominated James G. Blaine
|
334 |
+
of Maine, a vigorous and popular leader but a man under fire from the
|
335 |
+
reformers in his own party. The Democrats on their side were able to
|
336 |
+
find at this juncture an able candidate who had no political enemies in
|
337 |
+
the sphere of national politics, Grover Cleveland, then governor of New
|
338 |
+
York and widely celebrated as a man of "sterling honesty." At the same
|
339 |
+
time a number of dissatisfied Republicans openly espoused the Democratic
|
340 |
+
cause,--among them Carl Schurz, George William Curtis, Henry Ward
|
341 |
+
Beecher, and William Everett, men of fine ideals and undoubted
|
342 |
+
integrity. Though the "regular" Republicans called them "Mugwumps" and
|
343 |
+
laughed at them as the "men milliners, the dilettanti, and carpet
|
344 |
+
knights of politics," they had a following that was not to be despised.
|
345 |
+
|
346 |
+
The campaign which took place that year was one of the most savage in
|
347 |
+
American history. Issues were thrust into the background. The tariff,
|
348 |
+
though mentioned, was not taken seriously. Abuse of the opposition was
|
349 |
+
the favorite resource of party orators. The Democrats insisted that "the
|
350 |
+
Republican party so far as principle is concerned is a reminiscence. In
|
351 |
+
practice it is an organization for enriching those who control its
|
352 |
+
machinery." For the Republican candidate, Blaine, they could hardly find
|
353 |
+
words to express their contempt. The Republicans retaliated in kind.
|
354 |
+
They praised their own good works, as of old, in saving the union, and
|
355 |
+
denounced the "fraud and violence practiced by the Democracy in the
|
356 |
+
Southern states." Seeing little objectionable in the public record of
|
357 |
+
Cleveland as mayor of Buffalo and governor of New York, they attacked
|
358 |
+
his personal character. Perhaps never in the history of political
|
359 |
+
campaigns did the discussions on the platform and in the press sink to
|
360 |
+
so low a level. Decent people were sickened. Even hot partisans shrank
|
361 |
+
from their own words when, after the election, they had time to reflect
|
362 |
+
on their heedless passions. Moreover, nothing was decided by the
|
363 |
+
balloting. Cleveland was elected, but his victory was a narrow one. A
|
364 |
+
change of a few hundred votes in New York would have sent his opponent
|
365 |
+
to the White House instead.
|
366 |
+
|
367 |
+
=Changing Political Fortunes (1888-96).=--After the Democrats had
|
368 |
+
settled down to the enjoyment of their hard-earned victory, President
|
369 |
+
Cleveland in his message of 1887 attacked the tariff as "vicious,
|
370 |
+
inequitable, and illogical"; as a system of taxation that laid a burden
|
371 |
+
upon "every consumer in the land for the benefit of our manufacturers."
|
372 |
+
Business enterprise was thoroughly alarmed. The Republicans
|
373 |
+
characterized the tariff message as a free-trade assault upon the
|
374 |
+
industries of the country. Mainly on that issue they elected in 1888
|
375 |
+
Benjamin Harrison of Indiana, a shrewd lawyer, a reticent politician, a
|
376 |
+
descendant of the hero of Tippecanoe, and a son of the old Northwest.
|
377 |
+
Accepting the outcome of the election as a vindication of their
|
378 |
+
principles, the Republicans, under the leadership of William McKinley in
|
379 |
+
the House of Representatives, enacted in 1890 a tariff law imposing the
|
380 |
+
highest duties yet laid in our history. To their utter surprise,
|
381 |
+
however, they were instantly informed by the country that their program
|
382 |
+
was not approved. That very autumn they lost in the congressional
|
383 |
+
elections, and two years later they were decisively beaten in the
|
384 |
+
presidential campaign, Cleveland once more leading his party to victory.
|
385 |
+
|
386 |
+
|
387 |
+
=References=
|
388 |
+
|
389 |
+
L.H. Haney, _Congressional History of Railways_ (2 vols.).
|
390 |
+
|
391 |
+
J.P. Davis, _Union Pacific Railway_.
|
392 |
+
|
393 |
+
J.M. Swank, _History of the Manufacture of Iron_.
|
394 |
+
|
395 |
+
M.T. Copeland, _The Cotton Manufacturing Industry in the United States_
|
396 |
+
(Harvard Studies).
|
397 |
+
|
398 |
+
E.W. Bryce, _Progress of Invention in the Nineteenth Century_.
|
399 |
+
|
400 |
+
Ida Tarbell, _History of the Standard Oil Company_ (Critical).
|
401 |
+
|
402 |
+
G.H. Montague, _Rise and Progress of the Standard Oil Company_
|
403 |
+
(Friendly).
|
404 |
+
|
405 |
+
H.P. Fairchild, _Immigration_, and F.J. Warne, _The Immigrant Invasion_
|
406 |
+
(Both works favor exclusion).
|
407 |
+
|
408 |
+
I.A. Hourwich, _Immigration_ (Against exclusionist policies).
|
409 |
+
|
410 |
+
J.F. Rhodes, _History of the United States, 1877-1896_, Vol. VIII.
|
411 |
+
|
412 |
+
Edward Stanwood, _A History of the Presidency_, Vol. I, for the
|
413 |
+
presidential elections of the period.
|
414 |
+
|
415 |
+
|
416 |
+
=Questions=
|
417 |
+
|
418 |
+
1. Contrast the state of industry and commerce at the close of the Civil
|
419 |
+
War with its condition at the close of the Revolutionary War.
|
420 |
+
|
421 |
+
2. Enumerate the services rendered to the nation by the railways.
|
422 |
+
|
423 |
+
3. Explain the peculiar relation of railways to government.
|
424 |
+
|
425 |
+
4. What sections of the country have been industrialized?
|
426 |
+
|
427 |
+
5. How do you account for the rise and growth of the trusts? Explain
|
428 |
+
some of the economic advantages of the trust.
|
429 |
+
|
430 |
+
6. Are the people in cities more or less independent than the farmers?
|
431 |
+
What was Jefferson's view?
|
432 |
+
|
433 |
+
7. State some of the problems raised by unrestricted immigration.
|
434 |
+
|
435 |
+
8. What was the theory of the relation of government to business in this
|
436 |
+
period? Has it changed in recent times?
|
437 |
+
|
438 |
+
9. State the leading economic policies sponsored by the Republican
|
439 |
+
party.
|
440 |
+
|
441 |
+
10. Why were the Republicans especially strong immediately after the
|
442 |
+
Civil War?
|
443 |
+
|
444 |
+
11. What illustrations can you give showing the influence of war in
|
445 |
+
American political campaigns?
|
446 |
+
|
447 |
+
12. Account for the strength of middle-western candidates.
|
448 |
+
|
449 |
+
13. Enumerate some of the abuses that appeared in American political
|
450 |
+
life after 1865.
|
451 |
+
|
452 |
+
14. Sketch the rise and growth of the reform movement.
|
453 |
+
|
454 |
+
15. How is the fluctuating state of public opinion reflected in the
|
455 |
+
elections from 1880 to 1896?
|
456 |
+
|
457 |
+
|
458 |
+
=Research Topics=
|
459 |
+
|
460 |
+
=Invention, Discovery, and Transportation.=--Sparks, _National
|
461 |
+
Development_ (American Nation Series), pp. 37-67; Bogart, _Economic
|
462 |
+
History of the United States_, Chaps. XXI, XXII, and XXIII.
|
463 |
+
|
464 |
+
=Business and Politics.=--Paxson, _The New Nation_ (Riverside Series),
|
465 |
+
pp. 92-107; Rhodes, _History of the United States_, Vol. VII, pp. 1-29,
|
466 |
+
64-73, 175-206; Wilson, _History of the American People_, Vol. IV, pp.
|
467 |
+
78-96.
|
468 |
+
|
469 |
+
=Immigration.=--Coman, _Industrial History of the United States_ (2d
|
470 |
+
ed.), pp. 369-374; E.L. Bogart, _Economic History of the United States_,
|
471 |
+
pp. 420-422, 434-437; Jenks and Lauck, _Immigration Problems_, Commons,
|
472 |
+
_Races and Immigrants_.
|
473 |
+
|
474 |
+
=The Disputed Election of 1876.=--Haworth, _The United States in Our Own
|
475 |
+
Time_, pp. 82-94; Dunning, _Reconstruction, Political and Economic_
|
476 |
+
(American Nation Series), pp. 294-341; Elson, _History of the United
|
477 |
+
States_, pp. 835-841.
|
478 |
+
|
479 |
+
=Abuses in Political Life.=--Dunning, _Reconstruction_, pp. 281-293; see
|
480 |
+
criticisms in party platforms in Stanwood, _History of the Presidency_,
|
481 |
+
Vol. I; Bryce, _American Commonwealth_ (1910 ed.), Vol. II, pp. 379-448;
|
482 |
+
136-167.
|
483 |
+
|
484 |
+
=Studies of Presidential Administrations.=--(_a_) Grant, (_b_) Hayes,
|
485 |
+
(_c_) Garfield-Arthur, (_d_) Cleveland, and (_e_) Harrison, in Haworth,
|
486 |
+
_The United States in Our Own Time_, or in Paxson, _The New Nation_
|
487 |
+
(Riverside Series), or still more briefly in Elson.
|
488 |
+
|
489 |
+
=Cleveland Democracy.=--Haworth, _The United States_, pp. 164-183;
|
490 |
+
Rhodes, _History of the United States_, Vol. VIII, pp. 240-327; Elson,
|
491 |
+
pp. 857-887.
|
492 |
+
|
493 |
+
=Analysis of Modern Immigration Problems.=--_Syllabus in History_ (New
|
494 |
+
York State, 1919), pp. 110-112.
|
495 |
+
|
496 |
+
|
497 |
+
|
498 |
+
|
499 |
+
CHAPTER XVIII
|
500 |
+
|
MedicalGPT-main/data/pretrain/fever.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/data/pretrain/tianlongbabu.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/data/reward/test.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/data/vocab/baichuan_vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/data/vocab/word_freq.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
MedicalGPT-main/deepspeed_config.json
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"optimizer": {
|
3 |
+
"type": "AdamW",
|
4 |
+
"params": {
|
5 |
+
"lr": "auto",
|
6 |
+
"weight_decay": "auto",
|
7 |
+
"torch_adam": true,
|
8 |
+
"adam_w_mode": true
|
9 |
+
}
|
10 |
+
},
|
11 |
+
"scheduler": {
|
12 |
+
"type": "WarmupDecayLR",
|
13 |
+
"params": {
|
14 |
+
"warmup_min_lr": "auto",
|
15 |
+
"warmup_max_lr": "auto",
|
16 |
+
"warmup_num_steps": "auto",
|
17 |
+
"total_num_steps": "auto"
|
18 |
+
}
|
19 |
+
},
|
20 |
+
"fp16": {
|
21 |
+
"enabled": true,
|
22 |
+
"loss_scale": 0,
|
23 |
+
"loss_scale_window": 1000,
|
24 |
+
"initial_scale_power": 16,
|
25 |
+
"hysteresis": 2,
|
26 |
+
"min_loss_scale": 1
|
27 |
+
},
|
28 |
+
"zero_optimization": {
|
29 |
+
"stage": 2,
|
30 |
+
"allgather_partitions": true,
|
31 |
+
"allgather_bucket_size": 2e8,
|
32 |
+
"reduce_scatter": true,
|
33 |
+
"reduce_bucket_size": "auto",
|
34 |
+
"overlap_comm": true,
|
35 |
+
"contiguous_gradients": true
|
36 |
+
},
|
37 |
+
"gradient_accumulation_steps": "auto",
|
38 |
+
"gradient_clipping": "auto",
|
39 |
+
"steps_per_print": 1000,
|
40 |
+
"train_batch_size": "auto",
|
41 |
+
"train_micro_batch_size_per_gpu": "auto",
|
42 |
+
"wall_clock_breakdown": false
|
43 |
+
}
|
MedicalGPT-main/docs/GPT_Training.jpg
ADDED
MedicalGPT-main/docs/demo-screen.gif
ADDED
MedicalGPT-main/docs/dpo.jpg
ADDED
MedicalGPT-main/docs/logo.png
ADDED
Git LFS Details
|
MedicalGPT-main/docs/training_details.md
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Training Detail
|
2 |
+
|
3 |
+
|
4 |
+
### Stage 1: PT(Continue PreTraining)
|
5 |
+
第一阶段:PT(Continue PreTraining)增量预训练
|
6 |
+
|
7 |
+
使用百科类文档类数据集,用来在领域数据集上增量预训练或二次预训练,期望能把领域知识注入给模型,以医疗领域为例,希望增量预训练,能让模型理解感冒的症状、病因、治疗药品、治疗方法、药品疗效等知识,便于后续的SFT监督微调能激活这些内在知识。
|
8 |
+
|
9 |
+
这里说明一点,像GPT3、LLaMA这样的大模型理论上是可以从增量预训练中获益,但增量预训练需要满足两个要求:1)高质量的预训练样本;2)较大的计算资源,显存要求高,即使是用LoRA技术,也要满足block_size=1024或2048长度的文本加载到显存中。
|
10 |
+
|
11 |
+
其次,如果你的项目用到的数据是模型预训练中已经使用了的,如维基百科、ArXiv等LLaMA模型预训练用了的,则这些数据是没有必要再喂给LLaMA增量预训练,而且预训练样本的质量如果不够高,也可能会损害原模型的生成能力。
|
12 |
+
|
13 |
+
tips:PT阶段是可选项,请慎重处理。
|
14 |
+
|
15 |
+
基于llama-7b模型,使用医疗百科类数据继续预训练,期望注入医疗知识到预训练模型,得到llama-7b-pt模型
|
16 |
+
|
17 |
+
Continue pretraining of the base llama-7b model to create llama-7b-pt:
|
18 |
+
|
19 |
+
```shell
|
20 |
+
cd scripts
|
21 |
+
sh run_pt.sh
|
22 |
+
```
|
23 |
+
|
24 |
+
[训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82%E8%AF%B4%E6%98%8E)
|
25 |
+
- 如果你的显存不足,可以改小batch_size=1, block_size=512(影响训练的上下文最大长度);
|
26 |
+
- 如果你的显存更大,可以改大block_size=2048, 此为llama原始预训练长度,不能更大啦;调大batch_size。
|
27 |
+
|
28 |
+
### Stage 2: SFT(Supervised Fine-tuning)
|
29 |
+
第二阶段:SFT(Supervised Fine-tuning)有监督微调
|
30 |
+
|
31 |
+
基于llama-7b-pt模型,使用医疗问答类数据进行有监督微调,得到llama-7b-sft模型
|
32 |
+
|
33 |
+
Supervised fine-tuning of the base llama-7b-pt model to create llama-7b-sft
|
34 |
+
|
35 |
+
```shell
|
36 |
+
cd scripts
|
37 |
+
sh run_sft.sh
|
38 |
+
```
|
39 |
+
|
40 |
+
[训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82%E8%AF%B4%E6%98%8E)
|
41 |
+
|
42 |
+
### Stage 3: RLHF(Reinforcement Learning from Human Feedback)
|
43 |
+
#### Reward Modeling
|
44 |
+
RM(Reward Model)奖励模型建模
|
45 |
+
|
46 |
+
RM(Reward Model)奖励模型,原则上,我们可以直接用人类标注来对模型做 RLHF 微调。
|
47 |
+
|
48 |
+
然而,这将需要我们给人类发送一些样本,在每轮优化后计分。这是贵且慢的,因为收敛需要的训练样本量大,而人类阅读和标注的速度有限。
|
49 |
+
一个比直接反馈更好的策略是,在进入 RL 循环之前用人类标注集来训练一个奖励模型RM。奖励模型的目的是模拟人类对文本的打分。
|
50 |
+
|
51 |
+
构建奖励模型的最佳实践是预测结果的排序,即对每个 prompt (输入文本) 对应的两个结果 (yk, yj),模型预测人类标注的比分哪个更高。
|
52 |
+
RM模型是通过人工标注SFT模型的打分结果来训练的,目的是取代人工打分,本质是个回归模型,用来对齐人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"。
|
53 |
+
|
54 |
+
|
55 |
+
基于llama-7b-sft模型,使用医疗问答偏好数据训练奖励偏好模型,训练得到llama-7b-reward模型
|
56 |
+
|
57 |
+
Reward modeling using dialog pairs from the reward dataset using the llama-7b-sft to create llama-7b-reward:
|
58 |
+
|
59 |
+
```shell
|
60 |
+
cd scripts
|
61 |
+
sh run_rm.sh
|
62 |
+
```
|
63 |
+
[训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82%E8%AF%B4%E6%98%8E)
|
64 |
+
|
65 |
+
#### Reinforcement Learning
|
66 |
+
RL(Reinforcement Learning)强化学习
|
67 |
+
|
68 |
+
RL(Reinforcement Learning)模型的目的是最大化奖励模型的输出,基于上面步骤,我们有了微调的语言模型(llama-7b-sft)和奖励模型(llama-7b-reward),
|
69 |
+
可以开始执行 RL 循环了。
|
70 |
+
|
71 |
+
这个过程大致分为三步:
|
72 |
+
|
73 |
+
1. 输入prompt,模型生成答复
|
74 |
+
2. 用奖励模型来对答复评分
|
75 |
+
3. 基于评分,进行一轮策略优化的强化学习(PPO)
|
76 |
+
|
77 |
+
<img src=https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/trl_loop.png height=400 />
|
78 |
+
|
79 |
+
|
80 |
+
基于llama-7b-reward模型 RL 微调训练llama-7b-sft模型,得到llama-7b-rl模型
|
81 |
+
|
82 |
+
Reinforcement Learning fine-tuning of llama-7b-sft with the llama-7b-reward reward model to create llama-7b-rl
|
83 |
+
|
84 |
+
```shell
|
85 |
+
pip install git+https://github.com/lvwerra/trl
|
86 |
+
cd scripts
|
87 |
+
sh run_rl.sh
|
88 |
+
```
|
89 |
+
|
90 |
+
### Stage 3: DPO(Direct Preference Optimization)
|
91 |
+
DPO(Direct Preference Optimization)直接偏好优化
|
92 |
+
|
93 |
+
DPO方法可以通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习。
|
94 |
+
|
95 |
+
DPO 将奖励函数和最优策略之间的映射联系起来,从而把约束奖励最大化问题转化为一个单阶段的策略训练问题。
|
96 |
+
这种算法不仅不用拟合奖励模型,还避免了在微调过程中从语言模型中采样或调整超参数的需要。
|
97 |
+
|
98 |
+
实验结果表明,DPO 算法可以与现有RLHF方法一样有效地从人类偏好中学习,甚至在某些任务中表现更好,比如情感调节、摘要和单轮对话。
|
99 |
+
|
100 |
+
PS: 使用DPO训练LLaMA2-7B在fp16,batch_size为2时,需要70GB显存。
|
101 |
+
|
102 |
+
```shell
|
103 |
+
sh run_dpo.sh
|
104 |
+
```
|
MedicalGPT-main/docs/wechat.jpeg
ADDED
MedicalGPT-main/dpo_training.py
ADDED
@@ -0,0 +1,495 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description: Train a model from SFT using DPO
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from glob import glob
|
10 |
+
from typing import Dict, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from datasets import load_dataset
|
14 |
+
from loguru import logger
|
15 |
+
from peft import LoraConfig, TaskType
|
16 |
+
from transformers import (
|
17 |
+
AutoConfig,
|
18 |
+
BloomForCausalLM,
|
19 |
+
AutoModelForCausalLM,
|
20 |
+
AutoModel,
|
21 |
+
LlamaTokenizer,
|
22 |
+
LlamaForCausalLM,
|
23 |
+
BloomTokenizerFast,
|
24 |
+
AutoTokenizer,
|
25 |
+
HfArgumentParser,
|
26 |
+
TrainingArguments,
|
27 |
+
BitsAndBytesConfig,
|
28 |
+
)
|
29 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
30 |
+
from trl import DPOTrainer
|
31 |
+
|
32 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
|
33 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
34 |
+
|
35 |
+
MODEL_CLASSES = {
|
36 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
37 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
38 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
39 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
40 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
41 |
+
}
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
|
45 |
+
class ScriptArguments:
|
46 |
+
"""
|
47 |
+
The name of the Casual LM model we wish to fine with DPO
|
48 |
+
"""
|
49 |
+
# Model arguments
|
50 |
+
model_type: str = field(
|
51 |
+
default=None,
|
52 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
53 |
+
)
|
54 |
+
model_name_or_path: Optional[str] = field(
|
55 |
+
default=None, metadata={"help": "The model checkpoint for weights initialization."}
|
56 |
+
)
|
57 |
+
tokenizer_name_or_path: Optional[str] = field(
|
58 |
+
default=None, metadata={"help": "The tokenizer for weights initialization."}
|
59 |
+
)
|
60 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
61 |
+
load_in_4bit: bool = field(default=False, metadata={"help": "Whether to load the model in 4bit mode or not."})
|
62 |
+
cache_dir: Optional[str] = field(
|
63 |
+
default=None,
|
64 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
65 |
+
)
|
66 |
+
use_fast_tokenizer: bool = field(
|
67 |
+
default=False,
|
68 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
69 |
+
)
|
70 |
+
torch_dtype: Optional[str] = field(
|
71 |
+
default=None,
|
72 |
+
metadata={
|
73 |
+
"help": (
|
74 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
75 |
+
"dtype will be automatically derived from the model's weights."
|
76 |
+
),
|
77 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
78 |
+
},
|
79 |
+
)
|
80 |
+
device_map: Optional[str] = field(
|
81 |
+
default="auto",
|
82 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
83 |
+
)
|
84 |
+
trust_remote_code: bool = field(
|
85 |
+
default=True,
|
86 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
87 |
+
)
|
88 |
+
# Dataset arguments
|
89 |
+
dataset_name: Optional[str] = field(
|
90 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
91 |
+
)
|
92 |
+
dataset_config_name: Optional[str] = field(
|
93 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
94 |
+
)
|
95 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
|
96 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
|
97 |
+
template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
|
98 |
+
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "Train batch size per device"})
|
99 |
+
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "Eval batch size per device"})
|
100 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
101 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
102 |
+
min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
|
103 |
+
max_train_samples: Optional[int] = field(
|
104 |
+
default=None,
|
105 |
+
metadata={
|
106 |
+
"help": (
|
107 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
108 |
+
"value if set."
|
109 |
+
)
|
110 |
+
},
|
111 |
+
)
|
112 |
+
max_eval_samples: Optional[int] = field(
|
113 |
+
default=None,
|
114 |
+
metadata={
|
115 |
+
"help": (
|
116 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
117 |
+
"value if set."
|
118 |
+
)
|
119 |
+
},
|
120 |
+
)
|
121 |
+
overwrite_cache: bool = field(
|
122 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
123 |
+
)
|
124 |
+
validation_split_percentage: Optional[int] = field(
|
125 |
+
default=1,
|
126 |
+
metadata={
|
127 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
128 |
+
},
|
129 |
+
)
|
130 |
+
preprocessing_num_workers: Optional[int] = field(
|
131 |
+
default=4, metadata={"help": "The number of processes to use for the preprocessing."},
|
132 |
+
)
|
133 |
+
# Training arguments
|
134 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
135 |
+
qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
|
136 |
+
target_modules: Optional[str] = field(default=None)
|
137 |
+
lora_rank: Optional[int] = field(default=8)
|
138 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
139 |
+
lora_alpha: Optional[float] = field(default=16.0)
|
140 |
+
peft_path: Optional[str] = field(default=None)
|
141 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
142 |
+
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
|
143 |
+
beta: Optional[float] = field(default=0.1, metadata={"help": "The beta parameter for DPO loss"})
|
144 |
+
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "Learning rate"})
|
145 |
+
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "The lr scheduler type"})
|
146 |
+
warmup_steps: Optional[int] = field(default=100, metadata={"help": "The number of warmup steps"})
|
147 |
+
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "The weight decay"})
|
148 |
+
optim: Optional[str] = field(default="adamw_hf", metadata={"help": "The optimizer type"})
|
149 |
+
fp16: Optional[bool] = field(default=True, metadata={"help": "Whether to use fp16"})
|
150 |
+
bf16: Optional[bool] = field(default=False, metadata={"help": "Whether to use bf16"})
|
151 |
+
gradient_checkpointing: Optional[bool] = field(
|
152 |
+
default=True, metadata={"help": "Whether to use gradient checkpointing"}
|
153 |
+
)
|
154 |
+
gradient_accumulation_steps: Optional[int] = field(
|
155 |
+
default=4, metadata={"help": "The number of gradient accumulation steps"}
|
156 |
+
)
|
157 |
+
save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
|
158 |
+
eval_steps: Optional[int] = field(default=50, metadata={"help": "X steps to evaluate the model"})
|
159 |
+
logging_steps: Optional[int] = field(default=1, metadata={"help": "X steps to log the model"})
|
160 |
+
output_dir: Optional[str] = field(default="outputs-dpo", metadata={"help": "The output directory"})
|
161 |
+
max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
|
162 |
+
eval_strategy: Optional[str] = field(default="steps", metadata={"help": "Evaluation strategy"})
|
163 |
+
remove_unused_columns: Optional[bool] = field(
|
164 |
+
default=False,
|
165 |
+
metadata={"help": "Remove unused columns from the dataset if `datasets.Dataset` is used"},
|
166 |
+
)
|
167 |
+
report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
|
168 |
+
|
169 |
+
def __post_init__(self):
|
170 |
+
if self.model_type is None:
|
171 |
+
raise ValueError("You must specify a valid model_type to run training.")
|
172 |
+
if self.model_name_or_path is None:
|
173 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
174 |
+
|
175 |
+
|
176 |
+
def print_trainable_parameters(model):
|
177 |
+
"""
|
178 |
+
Prints the number of trainable parameters in the model.
|
179 |
+
"""
|
180 |
+
trainable_params = 0
|
181 |
+
all_param = 0
|
182 |
+
for _, param in model.named_parameters():
|
183 |
+
all_param += param.numel()
|
184 |
+
if param.requires_grad:
|
185 |
+
trainable_params += param.numel()
|
186 |
+
print(
|
187 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
188 |
+
)
|
189 |
+
|
190 |
+
|
191 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
192 |
+
"""Find all linear layer names in the model. reference from qlora paper."""
|
193 |
+
cls = torch.nn.Linear
|
194 |
+
if int4 or int8:
|
195 |
+
import bitsandbytes as bnb
|
196 |
+
if int4:
|
197 |
+
cls = bnb.nn.Linear4bit
|
198 |
+
elif int8:
|
199 |
+
cls = bnb.nn.Linear8bitLt
|
200 |
+
lora_module_names = set()
|
201 |
+
for name, module in peft_model.named_modules():
|
202 |
+
if isinstance(module, cls):
|
203 |
+
# last layer is not add to lora_module_names
|
204 |
+
if 'lm_head' in name:
|
205 |
+
continue
|
206 |
+
if 'output_layer' in name:
|
207 |
+
continue
|
208 |
+
names = name.split('.')
|
209 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
210 |
+
return sorted(lora_module_names)
|
211 |
+
|
212 |
+
|
213 |
+
def return_prompt_and_responses(examples) -> Dict[str, str]:
|
214 |
+
"""Load the paired dataset and convert it to the necessary format.
|
215 |
+
|
216 |
+
The dataset is converted to a dictionary with the following structure:
|
217 |
+
{
|
218 |
+
'prompt': List[str],
|
219 |
+
'chosen': List[str],
|
220 |
+
'rejected': List[str],
|
221 |
+
}
|
222 |
+
|
223 |
+
Prompts are structured as follows:
|
224 |
+
"Question: " + <prompt> + "\n\nAnswer: "
|
225 |
+
"""
|
226 |
+
return {
|
227 |
+
"prompt": ["Question: " + question + "\n\nAnswer: " for question in examples["question"]],
|
228 |
+
"chosen": examples["response_chosen"],
|
229 |
+
"rejected": examples["response_rejected"],
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
def main():
|
234 |
+
parser = HfArgumentParser(ScriptArguments)
|
235 |
+
args = parser.parse_args_into_dataclasses()[0]
|
236 |
+
logger.info(f"Parse args: {args}")
|
237 |
+
|
238 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
239 |
+
if args.model_type == 'bloom':
|
240 |
+
args.use_fast_tokenizer = True
|
241 |
+
# Load tokenizer
|
242 |
+
tokenizer_kwargs = {
|
243 |
+
"cache_dir": args.cache_dir,
|
244 |
+
"use_fast": args.use_fast_tokenizer,
|
245 |
+
"trust_remote_code": args.trust_remote_code,
|
246 |
+
}
|
247 |
+
tokenizer_name_or_path = args.tokenizer_name_or_path
|
248 |
+
if not tokenizer_name_or_path:
|
249 |
+
tokenizer_name_or_path = args.model_name_or_path
|
250 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
251 |
+
if tokenizer.pad_token_id is None:
|
252 |
+
tokenizer.pad_token_id = 0 # set as the <unk> token
|
253 |
+
|
254 |
+
# Get datasets
|
255 |
+
if args.dataset_name is not None:
|
256 |
+
# Downloading and loading a dataset from the hub.
|
257 |
+
raw_datasets = load_dataset(
|
258 |
+
args.dataset_name,
|
259 |
+
args.dataset_config_name,
|
260 |
+
cache_dir=args.cache_dir,
|
261 |
+
)
|
262 |
+
if "validation" not in raw_datasets.keys():
|
263 |
+
raw_datasets["validation"] = load_dataset(
|
264 |
+
args.dataset_name,
|
265 |
+
args.dataset_config_name,
|
266 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
267 |
+
cache_dir=args.cache_dir,
|
268 |
+
)
|
269 |
+
raw_datasets["train"] = load_dataset(
|
270 |
+
args.dataset_name,
|
271 |
+
args.dataset_config_name,
|
272 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
273 |
+
cache_dir=args.cache_dir,
|
274 |
+
)
|
275 |
+
else:
|
276 |
+
data_files = {}
|
277 |
+
if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
|
278 |
+
train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
|
279 |
+
f'{args.train_file_dir}/**/*.jsonl', recursive=True)
|
280 |
+
logger.info(f"train files: {', '.join(train_data_files)}")
|
281 |
+
data_files["train"] = train_data_files
|
282 |
+
if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
|
283 |
+
eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
284 |
+
f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
|
285 |
+
logger.info(f"eval files: {', '.join(eval_data_files)}")
|
286 |
+
data_files["validation"] = eval_data_files
|
287 |
+
raw_datasets = load_dataset(
|
288 |
+
'json',
|
289 |
+
data_files=data_files,
|
290 |
+
cache_dir=args.cache_dir,
|
291 |
+
)
|
292 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
293 |
+
if "validation" not in raw_datasets.keys():
|
294 |
+
raw_datasets["validation"] = load_dataset(
|
295 |
+
'json',
|
296 |
+
data_files=data_files,
|
297 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
298 |
+
cache_dir=args.cache_dir,
|
299 |
+
)
|
300 |
+
raw_datasets["train"] = load_dataset(
|
301 |
+
'json',
|
302 |
+
data_files=data_files,
|
303 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
304 |
+
cache_dir=args.cache_dir,
|
305 |
+
)
|
306 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
307 |
+
|
308 |
+
# Preprocessing the datasets
|
309 |
+
max_source_length = args.max_source_length
|
310 |
+
max_target_length = args.max_target_length
|
311 |
+
full_max_length = max_source_length + max_target_length
|
312 |
+
|
313 |
+
# Preprocess the dataset
|
314 |
+
train_dataset = None
|
315 |
+
max_train_samples = 0
|
316 |
+
if args.do_train:
|
317 |
+
if "train" not in raw_datasets:
|
318 |
+
raise ValueError("--do_train requires a train dataset")
|
319 |
+
train_dataset = raw_datasets['train']
|
320 |
+
max_train_samples = len(train_dataset)
|
321 |
+
if args.max_train_samples is not None and args.max_train_samples > 0:
|
322 |
+
max_train_samples = min(len(train_dataset), args.max_train_samples)
|
323 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
324 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
325 |
+
tokenized_dataset = train_dataset.shuffle().map(
|
326 |
+
return_prompt_and_responses,
|
327 |
+
batched=True,
|
328 |
+
num_proc=args.preprocessing_num_workers,
|
329 |
+
remove_columns=train_dataset.column_names,
|
330 |
+
load_from_cache_file=not args.overwrite_cache,
|
331 |
+
desc="Running tokenizer on dataset",
|
332 |
+
)
|
333 |
+
train_dataset = tokenized_dataset.filter(
|
334 |
+
lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
|
335 |
+
and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
|
336 |
+
)
|
337 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
338 |
+
logger.debug("First train example:")
|
339 |
+
logger.debug(train_dataset[0]['prompt'] + train_dataset[0]['chosen'])
|
340 |
+
|
341 |
+
eval_dataset = None
|
342 |
+
max_eval_samples = 0
|
343 |
+
if args.do_eval:
|
344 |
+
if "validation" not in raw_datasets:
|
345 |
+
raise ValueError("--do_eval requires a validation dataset")
|
346 |
+
eval_dataset = raw_datasets["validation"]
|
347 |
+
max_eval_samples = len(eval_dataset)
|
348 |
+
if args.max_eval_samples is not None and args.max_eval_samples > 0:
|
349 |
+
max_eval_samples = min(len(eval_dataset), args.max_eval_samples)
|
350 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
351 |
+
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
|
352 |
+
eval_dataset = eval_dataset.map(
|
353 |
+
return_prompt_and_responses,
|
354 |
+
batched=True,
|
355 |
+
num_proc=args.preprocessing_num_workers,
|
356 |
+
remove_columns=eval_dataset.column_names,
|
357 |
+
load_from_cache_file=not args.overwrite_cache,
|
358 |
+
desc="Running tokenizer on dataset",
|
359 |
+
)
|
360 |
+
eval_dataset = eval_dataset.filter(
|
361 |
+
lambda x: 0 < len(x['prompt'] + x['chosen']) <= full_max_length
|
362 |
+
and 0 < len(x['prompt'] + x['rejected']) <= full_max_length
|
363 |
+
)
|
364 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
365 |
+
logger.debug("First eval example:")
|
366 |
+
logger.debug(eval_dataset[0]['prompt'] + eval_dataset[0]['chosen'])
|
367 |
+
|
368 |
+
logger.info("Loading model")
|
369 |
+
torch_dtype = (
|
370 |
+
args.torch_dtype
|
371 |
+
if args.torch_dtype in ["auto", None]
|
372 |
+
else getattr(torch, args.torch_dtype)
|
373 |
+
)
|
374 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
375 |
+
ddp = world_size != 1
|
376 |
+
if ddp:
|
377 |
+
args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
378 |
+
if args.qlora and is_deepspeed_zero3_enabled():
|
379 |
+
logger.warning("ZeRO3 are both currently incompatible with QLoRA.")
|
380 |
+
config = config_class.from_pretrained(
|
381 |
+
args.model_name_or_path,
|
382 |
+
trust_remote_code=args.trust_remote_code,
|
383 |
+
torch_dtype=torch_dtype,
|
384 |
+
cache_dir=args.cache_dir
|
385 |
+
)
|
386 |
+
model = model_class.from_pretrained(
|
387 |
+
args.model_name_or_path,
|
388 |
+
config=config,
|
389 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
390 |
+
device_map=args.device_map,
|
391 |
+
trust_remote_code=args.trust_remote_code,
|
392 |
+
quantization_config=BitsAndBytesConfig(
|
393 |
+
load_in_4bit=True,
|
394 |
+
bnb_4bit_use_double_quant=True,
|
395 |
+
bnb_4bit_quant_type="nf4",
|
396 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
397 |
+
) if args.qlora else None,
|
398 |
+
)
|
399 |
+
model_ref = model_class.from_pretrained(
|
400 |
+
args.model_name_or_path,
|
401 |
+
config=config,
|
402 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
403 |
+
device_map=args.device_map,
|
404 |
+
trust_remote_code=args.trust_remote_code,
|
405 |
+
quantization_config=BitsAndBytesConfig(
|
406 |
+
load_in_4bit=True,
|
407 |
+
bnb_4bit_use_double_quant=True,
|
408 |
+
bnb_4bit_quant_type="nf4",
|
409 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
410 |
+
) if args.qlora else None,
|
411 |
+
)
|
412 |
+
|
413 |
+
# Initialize our Trainer
|
414 |
+
if args.gradient_checkpointing:
|
415 |
+
model.gradient_checkpointing_enable()
|
416 |
+
model.config.use_cache = False
|
417 |
+
else:
|
418 |
+
model.config.use_cache = True
|
419 |
+
|
420 |
+
training_args = TrainingArguments(
|
421 |
+
per_device_train_batch_size=args.per_device_train_batch_size,
|
422 |
+
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
423 |
+
max_steps=args.max_steps,
|
424 |
+
logging_steps=args.logging_steps,
|
425 |
+
save_steps=args.save_steps,
|
426 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
427 |
+
gradient_checkpointing=args.gradient_checkpointing,
|
428 |
+
learning_rate=args.learning_rate,
|
429 |
+
evaluation_strategy=args.eval_strategy,
|
430 |
+
eval_steps=args.eval_steps,
|
431 |
+
output_dir=args.output_dir,
|
432 |
+
report_to=args.report_to,
|
433 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
434 |
+
warmup_steps=args.warmup_steps,
|
435 |
+
optim=args.optim,
|
436 |
+
bf16=args.bf16,
|
437 |
+
fp16=args.fp16,
|
438 |
+
remove_unused_columns=args.remove_unused_columns,
|
439 |
+
run_name=f"dpo_{args.model_type}",
|
440 |
+
)
|
441 |
+
|
442 |
+
# Initialize DPO trainer
|
443 |
+
target_modules = args.target_modules.split(',') if args.target_modules else None
|
444 |
+
if target_modules and 'all' in target_modules:
|
445 |
+
target_modules = find_all_linear_names(model, int4=args.load_in_4bit, int8=args.load_in_8bit)
|
446 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
447 |
+
peft_config = LoraConfig(
|
448 |
+
task_type=TaskType.CAUSAL_LM,
|
449 |
+
target_modules=target_modules,
|
450 |
+
inference_mode=False,
|
451 |
+
r=args.lora_rank,
|
452 |
+
lora_alpha=args.lora_alpha,
|
453 |
+
lora_dropout=args.lora_dropout,
|
454 |
+
)
|
455 |
+
trainer = DPOTrainer(
|
456 |
+
model,
|
457 |
+
model_ref,
|
458 |
+
args=training_args,
|
459 |
+
beta=args.beta,
|
460 |
+
train_dataset=train_dataset,
|
461 |
+
eval_dataset=eval_dataset,
|
462 |
+
tokenizer=tokenizer,
|
463 |
+
peft_config=peft_config if args.use_peft else None,
|
464 |
+
max_prompt_length=args.max_source_length,
|
465 |
+
max_length=full_max_length,
|
466 |
+
)
|
467 |
+
print_trainable_parameters(trainer.model)
|
468 |
+
|
469 |
+
# Training
|
470 |
+
if args.do_train:
|
471 |
+
logger.info("*** Train ***")
|
472 |
+
train_result = trainer.train()
|
473 |
+
metrics = train_result.metrics
|
474 |
+
metrics["train_samples"] = max_train_samples
|
475 |
+
logger.debug(f"Training metrics: {metrics}")
|
476 |
+
trainer.log_metrics("train", metrics)
|
477 |
+
trainer.save_metrics("train", metrics)
|
478 |
+
trainer.save_state()
|
479 |
+
logger.info(f"Saving model checkpoint to {args.output_dir}")
|
480 |
+
trainer.save_model(args.output_dir)
|
481 |
+
tokenizer.save_pretrained(args.output_dir)
|
482 |
+
trainer.model.save_pretrained(args.output_dir)
|
483 |
+
|
484 |
+
# Evaluation
|
485 |
+
if args.do_eval and trainer.is_world_process_zero():
|
486 |
+
logger.info("*** Evaluate ***")
|
487 |
+
metrics = trainer.evaluate()
|
488 |
+
metrics["eval_samples"] = max_eval_samples
|
489 |
+
logger.debug(f"Eval metrics: {metrics}")
|
490 |
+
trainer.log_metrics("eval", metrics)
|
491 |
+
trainer.save_metrics("eval", metrics)
|
492 |
+
|
493 |
+
|
494 |
+
if __name__ == "__main__":
|
495 |
+
main()
|
MedicalGPT-main/gradio_demo.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description:
|
5 |
+
|
6 |
+
pip install gradio
|
7 |
+
pip install mdtex2html
|
8 |
+
"""
|
9 |
+
import argparse
|
10 |
+
import os
|
11 |
+
from threading import Thread
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import mdtex2html
|
15 |
+
import torch
|
16 |
+
from peft import PeftModel
|
17 |
+
from transformers import (
|
18 |
+
AutoModel,
|
19 |
+
AutoTokenizer,
|
20 |
+
AutoModelForCausalLM,
|
21 |
+
BloomForCausalLM,
|
22 |
+
BloomTokenizerFast,
|
23 |
+
LlamaTokenizer,
|
24 |
+
LlamaForCausalLM,
|
25 |
+
GenerationConfig,
|
26 |
+
TextIteratorStreamer,
|
27 |
+
)
|
28 |
+
|
29 |
+
from supervised_finetuning import get_conv_template
|
30 |
+
|
31 |
+
MODEL_CLASSES = {
|
32 |
+
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
33 |
+
"chatglm": (AutoModel, AutoTokenizer),
|
34 |
+
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
35 |
+
"baichuan": (AutoModelForCausalLM, AutoTokenizer),
|
36 |
+
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
@torch.inference_mode()
|
41 |
+
def stream_generate_answer(
|
42 |
+
model,
|
43 |
+
tokenizer,
|
44 |
+
prompt,
|
45 |
+
device,
|
46 |
+
max_new_tokens=512,
|
47 |
+
temperature=0.7,
|
48 |
+
top_p=0.8,
|
49 |
+
repetition_penalty=1.0,
|
50 |
+
context_len=2048,
|
51 |
+
):
|
52 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
|
53 |
+
input_ids = tokenizer(prompt).input_ids
|
54 |
+
max_src_len = context_len - max_new_tokens - 8
|
55 |
+
input_ids = input_ids[-max_src_len:]
|
56 |
+
generation_kwargs = dict(
|
57 |
+
input_ids=torch.as_tensor([input_ids]).to(device),
|
58 |
+
max_new_tokens=max_new_tokens,
|
59 |
+
temperature=temperature,
|
60 |
+
top_p=top_p,
|
61 |
+
repetition_penalty=repetition_penalty,
|
62 |
+
streamer=streamer,
|
63 |
+
)
|
64 |
+
|
65 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
66 |
+
thread.start()
|
67 |
+
|
68 |
+
yield from streamer
|
69 |
+
|
70 |
+
|
71 |
+
def main():
|
72 |
+
parser = argparse.ArgumentParser()
|
73 |
+
parser.add_argument('--model_type', default=None, type=str, required=True)
|
74 |
+
parser.add_argument('--base_model', default=None, type=str, required=True)
|
75 |
+
parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
|
76 |
+
parser.add_argument('--tokenizer_path', default=None, type=str)
|
77 |
+
parser.add_argument('--template_name', default="vicuna", type=str,
|
78 |
+
help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
|
79 |
+
parser.add_argument('--gpus', default="0", type=str)
|
80 |
+
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
|
81 |
+
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
|
82 |
+
args = parser.parse_args()
|
83 |
+
if args.only_cpu is True:
|
84 |
+
args.gpus = ""
|
85 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
86 |
+
|
87 |
+
def postprocess(self, y):
|
88 |
+
if y is None:
|
89 |
+
return []
|
90 |
+
for i, (message, response) in enumerate(y):
|
91 |
+
y[i] = (
|
92 |
+
None if message is None else mdtex2html.convert((message)),
|
93 |
+
None if response is None else mdtex2html.convert(response),
|
94 |
+
)
|
95 |
+
return y
|
96 |
+
|
97 |
+
gr.Chatbot.postprocess = postprocess
|
98 |
+
|
99 |
+
load_type = torch.float16
|
100 |
+
if torch.cuda.is_available():
|
101 |
+
device = torch.device(0)
|
102 |
+
else:
|
103 |
+
device = torch.device('cpu')
|
104 |
+
|
105 |
+
if args.tokenizer_path is None:
|
106 |
+
args.tokenizer_path = args.base_model
|
107 |
+
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
108 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
|
109 |
+
base_model = model_class.from_pretrained(
|
110 |
+
args.base_model,
|
111 |
+
load_in_8bit=False,
|
112 |
+
torch_dtype=load_type,
|
113 |
+
low_cpu_mem_usage=True,
|
114 |
+
device_map='auto',
|
115 |
+
trust_remote_code=True,
|
116 |
+
)
|
117 |
+
try:
|
118 |
+
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
|
119 |
+
except OSError:
|
120 |
+
print("Failed to load generation config, use default.")
|
121 |
+
if args.resize_emb:
|
122 |
+
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
|
123 |
+
tokenzier_vocab_size = len(tokenizer)
|
124 |
+
print(f"Vocab of the base model: {model_vocab_size}")
|
125 |
+
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
|
126 |
+
if model_vocab_size != tokenzier_vocab_size:
|
127 |
+
print("Resize model embeddings to fit tokenizer")
|
128 |
+
base_model.resize_token_embeddings(tokenzier_vocab_size)
|
129 |
+
if args.lora_model:
|
130 |
+
model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
|
131 |
+
print("loaded lora model")
|
132 |
+
else:
|
133 |
+
model = base_model
|
134 |
+
if device == torch.device('cpu'):
|
135 |
+
model.float()
|
136 |
+
|
137 |
+
model.eval()
|
138 |
+
|
139 |
+
def reset_user_input():
|
140 |
+
return gr.update(value='')
|
141 |
+
|
142 |
+
def reset_state():
|
143 |
+
return [], []
|
144 |
+
|
145 |
+
prompt_template = get_conv_template(args.template_name)
|
146 |
+
stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str
|
147 |
+
history = []
|
148 |
+
|
149 |
+
def predict(
|
150 |
+
input,
|
151 |
+
chatbot,
|
152 |
+
history,
|
153 |
+
max_new_tokens,
|
154 |
+
temperature,
|
155 |
+
top_p
|
156 |
+
):
|
157 |
+
now_input = input
|
158 |
+
chatbot.append((input, ""))
|
159 |
+
history = history or []
|
160 |
+
history.append([now_input, ''])
|
161 |
+
|
162 |
+
prompt = prompt_template.get_prompt(messages=history)
|
163 |
+
response = ""
|
164 |
+
|
165 |
+
for new_text in stream_generate_answer(
|
166 |
+
model,
|
167 |
+
tokenizer,
|
168 |
+
prompt,
|
169 |
+
device,
|
170 |
+
max_new_tokens=max_new_tokens,
|
171 |
+
temperature=temperature,
|
172 |
+
top_p=top_p,
|
173 |
+
):
|
174 |
+
stop = False
|
175 |
+
pos = new_text.find(stop_str)
|
176 |
+
if pos != -1:
|
177 |
+
new_text = new_text[:pos]
|
178 |
+
stop = True
|
179 |
+
response += new_text
|
180 |
+
new_history = history + [(now_input, response)]
|
181 |
+
chatbot[-1] = (now_input, response)
|
182 |
+
yield chatbot, new_history
|
183 |
+
if stop:
|
184 |
+
break
|
185 |
+
|
186 |
+
with gr.Blocks() as demo:
|
187 |
+
gr.HTML("""<h1 align="center">MedicalGPT</h1>""")
|
188 |
+
gr.Markdown(
|
189 |
+
"> 为了促进医疗行业大模型的开放研究,本项目开源了MedicalGPT医疗大模型")
|
190 |
+
chatbot = gr.Chatbot()
|
191 |
+
with gr.Row():
|
192 |
+
with gr.Column(scale=4):
|
193 |
+
with gr.Column(scale=12):
|
194 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
195 |
+
container=False)
|
196 |
+
with gr.Column(min_width=32, scale=1):
|
197 |
+
submitBtn = gr.Button("Submit", variant="primary")
|
198 |
+
with gr.Column(scale=1):
|
199 |
+
emptyBtn = gr.Button("Clear History")
|
200 |
+
max_length = gr.Slider(
|
201 |
+
0, 4096, value=512, step=1.0, label="Maximum length", interactive=True)
|
202 |
+
top_p = gr.Slider(0, 1, value=0.8, step=0.01,
|
203 |
+
label="Top P", interactive=True)
|
204 |
+
temperature = gr.Slider(
|
205 |
+
0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)
|
206 |
+
history = gr.State([])
|
207 |
+
submitBtn.click(predict, [user_input, chatbot, history, max_length, temperature, top_p], [chatbot, history],
|
208 |
+
show_progress=True)
|
209 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
210 |
+
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
211 |
+
demo.queue().launch(share=False, inbrowser=True, server_name='0.0.0.0', server_port=8082)
|
212 |
+
|
213 |
+
|
214 |
+
if __name__ == '__main__':
|
215 |
+
main()
|
MedicalGPT-main/inference.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
+
import argparse
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from threading import Thread
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from peft import PeftModel
|
13 |
+
from transformers import (
|
14 |
+
AutoModel,
|
15 |
+
AutoModelForCausalLM,
|
16 |
+
AutoTokenizer,
|
17 |
+
BloomForCausalLM,
|
18 |
+
BloomTokenizerFast,
|
19 |
+
LlamaTokenizer,
|
20 |
+
LlamaForCausalLM,
|
21 |
+
TextIteratorStreamer,
|
22 |
+
GenerationConfig,
|
23 |
+
)
|
24 |
+
|
25 |
+
from supervised_finetuning import get_conv_template
|
26 |
+
|
27 |
+
MODEL_CLASSES = {
|
28 |
+
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
29 |
+
"chatglm": (AutoModel, AutoTokenizer),
|
30 |
+
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
31 |
+
"baichuan": (AutoModelForCausalLM, AutoTokenizer),
|
32 |
+
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
@torch.inference_mode()
|
37 |
+
def stream_generate_answer(
|
38 |
+
model,
|
39 |
+
tokenizer,
|
40 |
+
prompt,
|
41 |
+
device,
|
42 |
+
do_print=True,
|
43 |
+
max_new_tokens=512,
|
44 |
+
temperature=0.7,
|
45 |
+
repetition_penalty=1.0,
|
46 |
+
context_len=2048,
|
47 |
+
stop_str="</s>",
|
48 |
+
):
|
49 |
+
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
|
50 |
+
input_ids = tokenizer(prompt).input_ids
|
51 |
+
max_src_len = context_len - max_new_tokens - 8
|
52 |
+
input_ids = input_ids[-max_src_len:]
|
53 |
+
generation_kwargs = dict(
|
54 |
+
input_ids=torch.as_tensor([input_ids]).to(device),
|
55 |
+
max_new_tokens=max_new_tokens,
|
56 |
+
temperature=temperature,
|
57 |
+
repetition_penalty=repetition_penalty,
|
58 |
+
streamer=streamer,
|
59 |
+
)
|
60 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
61 |
+
thread.start()
|
62 |
+
|
63 |
+
generated_text = ""
|
64 |
+
for new_text in streamer:
|
65 |
+
stop = False
|
66 |
+
pos = new_text.find(stop_str)
|
67 |
+
if pos != -1:
|
68 |
+
new_text = new_text[:pos]
|
69 |
+
stop = True
|
70 |
+
generated_text += new_text
|
71 |
+
if do_print:
|
72 |
+
print(new_text, end="", flush=True)
|
73 |
+
if stop:
|
74 |
+
break
|
75 |
+
if do_print:
|
76 |
+
print()
|
77 |
+
return generated_text
|
78 |
+
|
79 |
+
|
80 |
+
def main():
|
81 |
+
parser = argparse.ArgumentParser()
|
82 |
+
parser.add_argument('--model_type', default=None, type=str, required=True)
|
83 |
+
parser.add_argument('--base_model', default=None, type=str, required=True)
|
84 |
+
parser.add_argument('--lora_model', default="", type=str, help="If None, perform inference on the base model")
|
85 |
+
parser.add_argument('--tokenizer_path', default=None, type=str)
|
86 |
+
parser.add_argument('--template_name', default="vicuna", type=str,
|
87 |
+
help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.")
|
88 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
89 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.0)
|
90 |
+
parser.add_argument("--max_new_tokens", type=int, default=512)
|
91 |
+
parser.add_argument('--data_file', default=None, type=str,
|
92 |
+
help="A file that contains instructions (one instruction per line)")
|
93 |
+
parser.add_argument('--interactive', action='store_true', help="run in the instruction mode (single-turn)")
|
94 |
+
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
|
95 |
+
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
|
96 |
+
parser.add_argument('--gpus', default="0", type=str)
|
97 |
+
parser.add_argument('--only_cpu', action='store_true', help='only use CPU for inference')
|
98 |
+
args = parser.parse_args()
|
99 |
+
print(args)
|
100 |
+
if args.only_cpu is True:
|
101 |
+
args.gpus = ""
|
102 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
103 |
+
load_type = torch.float16
|
104 |
+
if torch.cuda.is_available():
|
105 |
+
device = torch.device(0)
|
106 |
+
else:
|
107 |
+
device = torch.device('cpu')
|
108 |
+
if args.tokenizer_path is None:
|
109 |
+
args.tokenizer_path = args.base_model
|
110 |
+
|
111 |
+
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
112 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
|
113 |
+
base_model = model_class.from_pretrained(
|
114 |
+
args.base_model,
|
115 |
+
load_in_8bit=False,
|
116 |
+
torch_dtype=load_type,
|
117 |
+
low_cpu_mem_usage=True,
|
118 |
+
device_map='auto',
|
119 |
+
trust_remote_code=True,
|
120 |
+
)
|
121 |
+
try:
|
122 |
+
base_model.generation_config = GenerationConfig.from_pretrained(args.base_model, trust_remote_code=True)
|
123 |
+
except OSError:
|
124 |
+
print("Failed to load generation config, use default.")
|
125 |
+
if args.resize_emb:
|
126 |
+
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
|
127 |
+
tokenzier_vocab_size = len(tokenizer)
|
128 |
+
print(f"Vocab of the base model: {model_vocab_size}")
|
129 |
+
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
|
130 |
+
if model_vocab_size != tokenzier_vocab_size:
|
131 |
+
print("Resize model embeddings to fit tokenizer")
|
132 |
+
base_model.resize_token_embeddings(tokenzier_vocab_size)
|
133 |
+
|
134 |
+
if args.lora_model:
|
135 |
+
model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type, device_map='auto')
|
136 |
+
print("Loaded lora model")
|
137 |
+
else:
|
138 |
+
model = base_model
|
139 |
+
if device == torch.device('cpu'):
|
140 |
+
model.float()
|
141 |
+
model.eval()
|
142 |
+
print(tokenizer)
|
143 |
+
# test data
|
144 |
+
if args.data_file is None:
|
145 |
+
examples = ["介绍下北京", "乙肝和丙肝的区别?"]
|
146 |
+
else:
|
147 |
+
with open(args.data_file, 'r') as f:
|
148 |
+
examples = [l.strip() for l in f.readlines()]
|
149 |
+
print("first 10 examples:")
|
150 |
+
for example in examples[:10]:
|
151 |
+
print(example)
|
152 |
+
|
153 |
+
# Chat
|
154 |
+
prompt_template = get_conv_template(args.template_name)
|
155 |
+
stop_str = tokenizer.eos_token if tokenizer.eos_token else prompt_template.stop_str
|
156 |
+
|
157 |
+
if args.interactive:
|
158 |
+
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
159 |
+
history = []
|
160 |
+
while True:
|
161 |
+
try:
|
162 |
+
query = input(f"{prompt_template.roles[0]}: ")
|
163 |
+
except UnicodeDecodeError:
|
164 |
+
print("Detected decoding error at the inputs, please try again.")
|
165 |
+
continue
|
166 |
+
except Exception:
|
167 |
+
raise
|
168 |
+
if query == "":
|
169 |
+
print("Please input text, try again.")
|
170 |
+
continue
|
171 |
+
if query.strip() == "exit":
|
172 |
+
print("exit...")
|
173 |
+
break
|
174 |
+
if query.strip() == "clear":
|
175 |
+
history = []
|
176 |
+
print("history cleared.")
|
177 |
+
continue
|
178 |
+
|
179 |
+
print(f"{prompt_template.roles[1]}: ", end="", flush=True)
|
180 |
+
|
181 |
+
history.append([query, ''])
|
182 |
+
prompt = prompt_template.get_prompt(messages=history)
|
183 |
+
response = stream_generate_answer(
|
184 |
+
model,
|
185 |
+
tokenizer,
|
186 |
+
prompt,
|
187 |
+
device,
|
188 |
+
do_print=True,
|
189 |
+
max_new_tokens=args.max_new_tokens,
|
190 |
+
temperature=args.temperature,
|
191 |
+
repetition_penalty=args.repetition_penalty,
|
192 |
+
stop_str=stop_str,
|
193 |
+
)
|
194 |
+
if history:
|
195 |
+
history[-1][-1] = response.strip()
|
196 |
+
else:
|
197 |
+
print("Start inference.")
|
198 |
+
results = []
|
199 |
+
for index, example in enumerate(examples):
|
200 |
+
# Single turn inference
|
201 |
+
history = [[example, '']]
|
202 |
+
prompt = prompt_template.get_prompt(messages=history)
|
203 |
+
response = stream_generate_answer(
|
204 |
+
model,
|
205 |
+
tokenizer,
|
206 |
+
prompt,
|
207 |
+
device,
|
208 |
+
do_print=False,
|
209 |
+
max_new_tokens=args.max_new_tokens,
|
210 |
+
temperature=args.temperature,
|
211 |
+
repetition_penalty=args.repetition_penalty,
|
212 |
+
stop_str=stop_str,
|
213 |
+
)
|
214 |
+
response = response.strip()
|
215 |
+
print(f"======={index}=======")
|
216 |
+
print(f"Input: {example}\n")
|
217 |
+
print(f"Output: {response}\n")
|
218 |
+
results.append({"Input": prompt, "Output": response})
|
219 |
+
|
220 |
+
with open(args.predictions_file, 'w', encoding='utf-8') as f:
|
221 |
+
json.dump(results, f, ensure_ascii=False, indent=2)
|
222 |
+
|
223 |
+
|
224 |
+
if __name__ == '__main__':
|
225 |
+
main()
|
MedicalGPT-main/merge_peft_adapter.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description:
|
5 |
+
|
6 |
+
Usage:
|
7 |
+
python merge_peft_adapter.py \
|
8 |
+
--base_model_name_or_path path/to/llama/model \
|
9 |
+
--tokenizer_path path/to/llama/tokenizer \
|
10 |
+
--peft_model_path path/to/lora/model \
|
11 |
+
--output_dir path/to/output/dir
|
12 |
+
|
13 |
+
after merged, chatglm and baichuan model need copy python script to output dir.
|
14 |
+
"""
|
15 |
+
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from peft import PeftModel, PeftConfig
|
20 |
+
from transformers import (
|
21 |
+
AutoModel,
|
22 |
+
AutoTokenizer,
|
23 |
+
BloomForCausalLM,
|
24 |
+
BloomTokenizerFast,
|
25 |
+
AutoModelForCausalLM,
|
26 |
+
LlamaTokenizer,
|
27 |
+
LlamaForCausalLM,
|
28 |
+
AutoModelForSequenceClassification,
|
29 |
+
)
|
30 |
+
|
31 |
+
MODEL_CLASSES = {
|
32 |
+
"bloom": (BloomForCausalLM, BloomTokenizerFast),
|
33 |
+
"chatglm": (AutoModel, AutoTokenizer),
|
34 |
+
"llama": (LlamaForCausalLM, LlamaTokenizer),
|
35 |
+
"baichuan": (AutoModelForCausalLM, AutoTokenizer),
|
36 |
+
"auto": (AutoModelForCausalLM, AutoTokenizer),
|
37 |
+
}
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
parser = argparse.ArgumentParser()
|
42 |
+
parser.add_argument('--model_type', default=None, type=str, required=True)
|
43 |
+
parser.add_argument('--base_model_name_or_path', default=None, required=True, type=str,
|
44 |
+
help="Base model name or path")
|
45 |
+
parser.add_argument('--tokenizer_path', default=None, type=str,
|
46 |
+
help="Please specify tokenization path.")
|
47 |
+
parser.add_argument('--peft_model_path', default=None, required=True, type=str,
|
48 |
+
help="Please specify LoRA model to be merged.")
|
49 |
+
parser.add_argument('--resize_emb', action='store_true', help='Whether to resize model token embeddings')
|
50 |
+
parser.add_argument('--output_dir', default='./merged', type=str)
|
51 |
+
args = parser.parse_args()
|
52 |
+
print(args)
|
53 |
+
|
54 |
+
base_model_path = args.base_model_name_or_path
|
55 |
+
peft_model_path = args.peft_model_path
|
56 |
+
output_dir = args.output_dir
|
57 |
+
print(f"Base model: {base_model_path}")
|
58 |
+
print(f"LoRA model: {peft_model_path}")
|
59 |
+
peft_config = PeftConfig.from_pretrained(peft_model_path)
|
60 |
+
|
61 |
+
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
62 |
+
if peft_config.task_type == "SEQ_CLS":
|
63 |
+
print("Loading LoRA for sequence classification model")
|
64 |
+
if args.model_type == "chatglm":
|
65 |
+
raise ValueError("chatglm does not support sequence classification")
|
66 |
+
base_model = AutoModelForSequenceClassification.from_pretrained(
|
67 |
+
base_model_path,
|
68 |
+
load_in_8bit=False,
|
69 |
+
torch_dtype=torch.float16,
|
70 |
+
trust_remote_code=True,
|
71 |
+
device_map="auto",
|
72 |
+
)
|
73 |
+
else:
|
74 |
+
print("Loading LoRA for causal language model")
|
75 |
+
base_model = model_class.from_pretrained(
|
76 |
+
base_model_path,
|
77 |
+
load_in_8bit=False,
|
78 |
+
torch_dtype=torch.float16,
|
79 |
+
trust_remote_code=True,
|
80 |
+
device_map="auto",
|
81 |
+
)
|
82 |
+
if args.tokenizer_path:
|
83 |
+
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_path, trust_remote_code=True)
|
84 |
+
else:
|
85 |
+
tokenizer = tokenizer_class.from_pretrained(peft_model_path, trust_remote_code=True)
|
86 |
+
if args.resize_emb:
|
87 |
+
base_model_token_size = base_model.get_input_embeddings().weight.size(0)
|
88 |
+
if base_model_token_size != len(tokenizer):
|
89 |
+
base_model.resize_token_embeddings(len(tokenizer))
|
90 |
+
print(f"Resize vocabulary size {base_model_token_size} to {len(tokenizer)}")
|
91 |
+
|
92 |
+
lora_model = PeftModel.from_pretrained(
|
93 |
+
base_model,
|
94 |
+
peft_model_path,
|
95 |
+
device_map="auto",
|
96 |
+
torch_dtype=torch.float16,
|
97 |
+
)
|
98 |
+
lora_model.eval()
|
99 |
+
print(f"Merging with merge_and_unload...")
|
100 |
+
base_model = lora_model.merge_and_unload()
|
101 |
+
|
102 |
+
print("Saving to Hugging Face format...")
|
103 |
+
tokenizer.save_pretrained(output_dir)
|
104 |
+
base_model.save_pretrained(output_dir)
|
105 |
+
print(f"Done! model saved to {output_dir}")
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == '__main__':
|
109 |
+
main()
|
MedicalGPT-main/merge_tokenizers.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
|
8 |
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
9 |
+
from transformers import LlamaTokenizer
|
10 |
+
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
|
11 |
+
import sentencepiece as spm
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
|
15 |
+
def is_chinese(uchar):
|
16 |
+
"""判断一个unicode是否是汉字"""
|
17 |
+
return '\u4e00' <= uchar <= '\u9fa5'
|
18 |
+
|
19 |
+
|
20 |
+
def is_chinese_string(string):
|
21 |
+
"""判断是否全为汉字"""
|
22 |
+
return all(is_chinese(c) for c in string)
|
23 |
+
|
24 |
+
|
25 |
+
def load_baichuan_vocab(vocab_file):
|
26 |
+
words = set()
|
27 |
+
with open(vocab_file, "r", encoding="utf-8") as f:
|
28 |
+
for line in f:
|
29 |
+
if line.strip():
|
30 |
+
words.add(line.strip().split()[0])
|
31 |
+
return words
|
32 |
+
|
33 |
+
|
34 |
+
def load_jieba_vocab(jieba_vocab_file):
|
35 |
+
# Read jieba vocab and sort by freq
|
36 |
+
with open(jieba_vocab_file, "r", encoding="utf-8") as f:
|
37 |
+
lines = f.readlines()
|
38 |
+
word_freqs = [line.strip().split() for line in lines]
|
39 |
+
word_freqs.sort(key=lambda x: int(x[1]), reverse=True)
|
40 |
+
return word_freqs
|
41 |
+
|
42 |
+
|
43 |
+
def main():
|
44 |
+
parser = argparse.ArgumentParser()
|
45 |
+
parser.add_argument('--base_tokenizer_dir', default=None, type=str, required=True)
|
46 |
+
parser.add_argument('--domain_sp_model_file', default='./domain_sp.model', type=str)
|
47 |
+
parser.add_argument('--baichuan_vocab_file', default="data/vocab/baichuan_vocab.txt", type=str)
|
48 |
+
parser.add_argument('--add_jieba', action='store_true', help='Whether to add jieba vocab.')
|
49 |
+
parser.add_argument('--jieba_word_freq_file', default='data/vocab/word_freq.txt', type=str)
|
50 |
+
parser.add_argument('--jieba_word_size', default=20000, type=int)
|
51 |
+
|
52 |
+
args = parser.parse_args()
|
53 |
+
print(args)
|
54 |
+
|
55 |
+
# load
|
56 |
+
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
|
57 |
+
chinese_sp_model = spm.SentencePieceProcessor()
|
58 |
+
chinese_sp_model.Load(args.domain_sp_model_file)
|
59 |
+
|
60 |
+
llama_spm = sp_pb2_model.ModelProto()
|
61 |
+
llama_spm.ParseFromString(llama_tokenizer.sp_model.serialized_model_proto())
|
62 |
+
chinese_spm = sp_pb2_model.ModelProto()
|
63 |
+
chinese_spm.ParseFromString(chinese_sp_model.serialized_model_proto())
|
64 |
+
|
65 |
+
# print number of tokens
|
66 |
+
print(len(llama_tokenizer), len(chinese_sp_model))
|
67 |
+
print(llama_tokenizer.all_special_tokens)
|
68 |
+
print(llama_tokenizer.all_special_ids)
|
69 |
+
print(llama_tokenizer.special_tokens_map)
|
70 |
+
|
71 |
+
# Add Chinese tokens to LLaMA tokenizer
|
72 |
+
llama_spm_tokens_set = set(p.piece for p in llama_spm.pieces)
|
73 |
+
|
74 |
+
print(len(llama_spm_tokens_set))
|
75 |
+
print(f"Before:{len(llama_spm_tokens_set)}")
|
76 |
+
added_set = set()
|
77 |
+
for p in chinese_spm.pieces:
|
78 |
+
piece = p.piece
|
79 |
+
if piece not in llama_spm_tokens_set:
|
80 |
+
# print('picec', piece)
|
81 |
+
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
82 |
+
new_p.piece = piece
|
83 |
+
new_p.score = 0
|
84 |
+
llama_spm.pieces.append(new_p)
|
85 |
+
added_set.add(piece)
|
86 |
+
print(f"[add domain tokens]New model pieces: {len(llama_spm.pieces)}")
|
87 |
+
|
88 |
+
vocab = load_baichuan_vocab(args.baichuan_vocab_file)
|
89 |
+
print('baichuan vocab len:', len(vocab))
|
90 |
+
baichuan_vocab_set = set([i for i in vocab if is_chinese_string(i)])
|
91 |
+
print('baichuan chinese vocab size:', len(baichuan_vocab_set))
|
92 |
+
print('baichuan vocab head:', list(baichuan_vocab_set)[:10])
|
93 |
+
for p in baichuan_vocab_set:
|
94 |
+
piece = p
|
95 |
+
if piece not in llama_spm_tokens_set and piece not in added_set:
|
96 |
+
# print('baichuan picec', piece)
|
97 |
+
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
98 |
+
new_p.piece = piece
|
99 |
+
new_p.score = 0
|
100 |
+
llama_spm.pieces.append(new_p)
|
101 |
+
added_set.add(piece)
|
102 |
+
print(f"[add baichuan tokens]New model pieces: {len(llama_spm.pieces)}")
|
103 |
+
|
104 |
+
if args.add_jieba:
|
105 |
+
word_freqs = load_jieba_vocab(args.jieba_word_freq_file)
|
106 |
+
top_words = word_freqs[:args.jieba_word_size]
|
107 |
+
print('jieba top10 freq words:', top_words[:10])
|
108 |
+
jieba_vocab_set = set([i[0] for i in top_words if i])
|
109 |
+
print('jieba_vocab_set size:', len(jieba_vocab_set))
|
110 |
+
print('jieba_vocab head:', list(jieba_vocab_set)[:3])
|
111 |
+
for p in jieba_vocab_set:
|
112 |
+
piece = p
|
113 |
+
if piece not in llama_spm_tokens_set and piece not in added_set:
|
114 |
+
# print('jieba picec', piece)
|
115 |
+
new_p = sp_pb2_model.ModelProto().SentencePiece()
|
116 |
+
new_p.piece = piece
|
117 |
+
new_p.score = 0
|
118 |
+
llama_spm.pieces.append(new_p)
|
119 |
+
print(f"[add jieba tokens]New model pieces: {len(llama_spm.pieces)}")
|
120 |
+
|
121 |
+
# Save
|
122 |
+
output_sp_dir = 'merged_tokenizer_sp'
|
123 |
+
output_hf_dir = 'merged_tokenizer_hf' # the path to save Chinese-LLaMA tokenizer
|
124 |
+
os.makedirs(output_sp_dir, exist_ok=True)
|
125 |
+
with open(output_sp_dir + '/chinese_llama.model', 'wb') as f:
|
126 |
+
f.write(llama_spm.SerializeToString())
|
127 |
+
tokenizer = LlamaTokenizer(vocab_file=output_sp_dir + '/chinese_llama.model')
|
128 |
+
|
129 |
+
tokenizer.save_pretrained(output_hf_dir)
|
130 |
+
print(f"Chinese-LLaMA tokenizer has been saved to {output_hf_dir}")
|
131 |
+
|
132 |
+
# Test
|
133 |
+
llama_tokenizer = LlamaTokenizer.from_pretrained(args.base_tokenizer_dir)
|
134 |
+
chinese_llama_tokenizer = LlamaTokenizer.from_pretrained(output_hf_dir)
|
135 |
+
print(chinese_llama_tokenizer.all_special_tokens)
|
136 |
+
print(chinese_llama_tokenizer.all_special_ids)
|
137 |
+
print(chinese_llama_tokenizer.special_tokens_map)
|
138 |
+
print('old len:', len(llama_tokenizer), ' new len:', len(chinese_llama_tokenizer))
|
139 |
+
text = '''this is a test, hello world. thisisatesthelloworld,
|
140 |
+
慕容复来到河边,姑苏慕容氏在外面丢了人。
|
141 |
+
1号店一周岁了,我们一古脑儿买了10斤零食。
|
142 |
+
巴塞罗那足球俱乐部简称巴萨(Barça),是一家位于西班牙加泰罗尼亚巴塞罗那的足球俱乐部,于1899年由瑞士企业家胡安·甘伯所创立,世界球坛顶级足球俱乐部之一。俱乐部主场可容纳接近十万名观众,是全欧洲最大及世界第二大的足球场。
|
143 |
+
白日依山尽,黄河入海流。欲穷千里目,更上一层楼。'''
|
144 |
+
print("Test text:\n", text)
|
145 |
+
print(f"Tokenized by LLaMA tokenizer:{llama_tokenizer.tokenize(text)}")
|
146 |
+
print(f"Tokenized by Chinese-LLaMA tokenizer:{chinese_llama_tokenizer.tokenize(text)}")
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == '__main__':
|
150 |
+
main()
|
MedicalGPT-main/pretraining.py
ADDED
@@ -0,0 +1,678 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright 2023 XuMing(xuming624@qq.com) and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
17 |
+
|
18 |
+
part of this code is adapted from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
|
19 |
+
"""
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
from dataclasses import dataclass, field
|
23 |
+
from glob import glob
|
24 |
+
from itertools import chain
|
25 |
+
from typing import Optional, List, Dict, Any, Mapping
|
26 |
+
|
27 |
+
import numpy as np
|
28 |
+
import torch
|
29 |
+
from datasets import load_dataset
|
30 |
+
from loguru import logger
|
31 |
+
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
|
32 |
+
from sklearn.metrics import accuracy_score
|
33 |
+
from transformers import (
|
34 |
+
AutoConfig,
|
35 |
+
BloomForCausalLM,
|
36 |
+
AutoModelForCausalLM,
|
37 |
+
AutoModel,
|
38 |
+
LlamaTokenizer,
|
39 |
+
LlamaForCausalLM,
|
40 |
+
BloomTokenizerFast,
|
41 |
+
AutoTokenizer,
|
42 |
+
HfArgumentParser,
|
43 |
+
Trainer,
|
44 |
+
TrainingArguments,
|
45 |
+
is_torch_tpu_available,
|
46 |
+
set_seed,
|
47 |
+
)
|
48 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
49 |
+
from transformers.utils.versions import require_version
|
50 |
+
|
51 |
+
MODEL_CLASSES = {
|
52 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
53 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
54 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
55 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
56 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
57 |
+
}
|
58 |
+
|
59 |
+
|
60 |
+
@dataclass
|
61 |
+
class ModelArguments:
|
62 |
+
"""
|
63 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
64 |
+
"""
|
65 |
+
|
66 |
+
model_type: str = field(
|
67 |
+
default=None,
|
68 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
69 |
+
)
|
70 |
+
model_name_or_path: Optional[str] = field(
|
71 |
+
default=None,
|
72 |
+
metadata={
|
73 |
+
"help": (
|
74 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
75 |
+
)
|
76 |
+
},
|
77 |
+
)
|
78 |
+
tokenizer_name_or_path: Optional[str] = field(
|
79 |
+
default=None,
|
80 |
+
metadata={
|
81 |
+
"help": (
|
82 |
+
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
|
83 |
+
)
|
84 |
+
},
|
85 |
+
)
|
86 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
87 |
+
cache_dir: Optional[str] = field(
|
88 |
+
default=None,
|
89 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
90 |
+
)
|
91 |
+
use_fast_tokenizer: bool = field(
|
92 |
+
default=False,
|
93 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
94 |
+
)
|
95 |
+
torch_dtype: Optional[str] = field(
|
96 |
+
default=None,
|
97 |
+
metadata={
|
98 |
+
"help": (
|
99 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
100 |
+
"dtype will be automatically derived from the model's weights."
|
101 |
+
),
|
102 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
103 |
+
},
|
104 |
+
)
|
105 |
+
device_map: Optional[str] = field(
|
106 |
+
default="auto",
|
107 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
108 |
+
)
|
109 |
+
trust_remote_code: bool = field(
|
110 |
+
default=True,
|
111 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
112 |
+
)
|
113 |
+
|
114 |
+
def __post_init__(self):
|
115 |
+
if self.model_type is None:
|
116 |
+
raise ValueError(
|
117 |
+
"You must specify a valid model_type to run training. Available model types are " + ", ".join(
|
118 |
+
MODEL_CLASSES.keys()))
|
119 |
+
if self.model_name_or_path is None:
|
120 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
121 |
+
|
122 |
+
|
123 |
+
@dataclass
|
124 |
+
class DataTrainingArguments:
|
125 |
+
"""
|
126 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
127 |
+
"""
|
128 |
+
|
129 |
+
dataset_name: Optional[str] = field(
|
130 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
131 |
+
)
|
132 |
+
dataset_config_name: Optional[str] = field(
|
133 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
134 |
+
)
|
135 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train text data file folder."})
|
136 |
+
validation_file_dir: Optional[str] = field(
|
137 |
+
default=None,
|
138 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on text file folder."},
|
139 |
+
)
|
140 |
+
max_train_samples: Optional[int] = field(
|
141 |
+
default=None,
|
142 |
+
metadata={
|
143 |
+
"help": (
|
144 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
145 |
+
"value if set."
|
146 |
+
)
|
147 |
+
},
|
148 |
+
)
|
149 |
+
max_eval_samples: Optional[int] = field(
|
150 |
+
default=None,
|
151 |
+
metadata={
|
152 |
+
"help": (
|
153 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
154 |
+
"value if set."
|
155 |
+
)
|
156 |
+
},
|
157 |
+
)
|
158 |
+
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
159 |
+
block_size: Optional[int] = field(
|
160 |
+
default=1024,
|
161 |
+
metadata={
|
162 |
+
"help": (
|
163 |
+
"Optional input sequence length after tokenization. "
|
164 |
+
"The training dataset will be truncated in block of this size for training. "
|
165 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
166 |
+
)
|
167 |
+
},
|
168 |
+
)
|
169 |
+
overwrite_cache: bool = field(
|
170 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
171 |
+
)
|
172 |
+
validation_split_percentage: Optional[int] = field(
|
173 |
+
default=1,
|
174 |
+
metadata={
|
175 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
176 |
+
},
|
177 |
+
)
|
178 |
+
preprocessing_num_workers: Optional[int] = field(
|
179 |
+
default=None,
|
180 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
181 |
+
)
|
182 |
+
keep_linebreaks: bool = field(
|
183 |
+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
184 |
+
)
|
185 |
+
|
186 |
+
def __post_init__(self):
|
187 |
+
if self.streaming:
|
188 |
+
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
189 |
+
|
190 |
+
|
191 |
+
@dataclass
|
192 |
+
class PeftArguments(TrainingArguments):
|
193 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
194 |
+
target_modules: Optional[str] = field(default="all")
|
195 |
+
lora_rank: Optional[int] = field(default=8)
|
196 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
197 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
198 |
+
modules_to_save: Optional[str] = field(default=None)
|
199 |
+
peft_path: Optional[str] = field(default=None)
|
200 |
+
|
201 |
+
|
202 |
+
def accuracy(predictions, references, normalize=True, sample_weight=None):
|
203 |
+
return {
|
204 |
+
"accuracy": float(accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight))
|
205 |
+
}
|
206 |
+
|
207 |
+
|
208 |
+
def compute_metrics(eval_preds):
|
209 |
+
preds, labels = eval_preds
|
210 |
+
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
211 |
+
# by preprocess_logits_for_metrics, we need to shift the labels
|
212 |
+
labels = labels[:, 1:].reshape(-1)
|
213 |
+
preds = preds[:, :-1].reshape(-1)
|
214 |
+
return accuracy(predictions=preds, references=labels)
|
215 |
+
|
216 |
+
|
217 |
+
def preprocess_logits_for_metrics(logits, labels):
|
218 |
+
if isinstance(logits, tuple):
|
219 |
+
# Depending on the model and config, logits may contain extra tensors,
|
220 |
+
# like past_key_values, but logits always come first
|
221 |
+
logits = logits[0]
|
222 |
+
return logits.argmax(dim=-1)
|
223 |
+
|
224 |
+
|
225 |
+
def fault_tolerance_data_collator(features: List) -> Dict[str, Any]:
|
226 |
+
if not isinstance(features[0], Mapping):
|
227 |
+
features = [vars(f) for f in features]
|
228 |
+
first = features[0]
|
229 |
+
batch = {}
|
230 |
+
|
231 |
+
# Special handling for labels.
|
232 |
+
# Ensure that tensor is created with the correct type
|
233 |
+
if "label" in first and first["label"] is not None:
|
234 |
+
label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
|
235 |
+
dtype = torch.long if isinstance(label, int) else torch.float
|
236 |
+
batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
|
237 |
+
elif "label_ids" in first and first["label_ids"] is not None:
|
238 |
+
if isinstance(first["label_ids"], torch.Tensor):
|
239 |
+
batch["labels"] = torch.stack([f["label_ids"] for f in features])
|
240 |
+
else:
|
241 |
+
dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
|
242 |
+
batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
|
243 |
+
|
244 |
+
# Handling of all other possible keys.
|
245 |
+
# Again, we will use the first element to figure out which key/values are not None for this model.
|
246 |
+
try:
|
247 |
+
for k, v in first.items():
|
248 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
249 |
+
if isinstance(v, torch.Tensor):
|
250 |
+
batch[k] = torch.stack([f[k] for f in features])
|
251 |
+
elif isinstance(v, np.ndarray):
|
252 |
+
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
|
253 |
+
else:
|
254 |
+
batch[k] = torch.tensor([f[k] for f in features])
|
255 |
+
except ValueError: # quick fix by simply take the first example
|
256 |
+
for k, v in first.items():
|
257 |
+
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
258 |
+
if isinstance(v, torch.Tensor):
|
259 |
+
batch[k] = torch.stack([features[0][k]] * len(features))
|
260 |
+
elif isinstance(v, np.ndarray):
|
261 |
+
batch[k] = torch.tensor(np.stack([features[0][k]] * len(features)))
|
262 |
+
else:
|
263 |
+
batch[k] = torch.tensor([features[0][k]] * len(features))
|
264 |
+
|
265 |
+
return batch
|
266 |
+
|
267 |
+
|
268 |
+
class GroupTextsBuilder:
|
269 |
+
def __init__(self, max_seq_length):
|
270 |
+
self.max_seq_length = max_seq_length
|
271 |
+
|
272 |
+
def __call__(self, examples):
|
273 |
+
# Concatenate all texts.
|
274 |
+
firsts = {k: examples[k][0][0] for k in examples.keys()}
|
275 |
+
lasts = {k: examples[k][0][-1] for k in examples.keys()}
|
276 |
+
contents = {k: sum([vi[1:-1] for vi in v], []) for k, v in examples.items()}
|
277 |
+
total_length = len(contents[list(examples.keys())[0]])
|
278 |
+
|
279 |
+
content_length = self.max_seq_length - 2
|
280 |
+
if total_length >= content_length:
|
281 |
+
total_length = (total_length // content_length) * content_length
|
282 |
+
# Split by chunks of max_len.
|
283 |
+
result = {
|
284 |
+
k: [[firsts[k]] + t[i: i + content_length] + [lasts[k]] for i in range(0, total_length, content_length)] for
|
285 |
+
k, t in contents.items()}
|
286 |
+
return result
|
287 |
+
|
288 |
+
|
289 |
+
class SavePeftModelTrainer(Trainer):
|
290 |
+
"""
|
291 |
+
Trainer for lora models
|
292 |
+
"""
|
293 |
+
|
294 |
+
def save_model(self, output_dir=None, _internal_call=False):
|
295 |
+
"""Save the LoRA model."""
|
296 |
+
os.makedirs(output_dir, exist_ok=True)
|
297 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
298 |
+
self.model.save_pretrained(output_dir)
|
299 |
+
|
300 |
+
|
301 |
+
def save_model(output_dir, model, tokenizer, args):
|
302 |
+
"""Save the model and the tokenizer."""
|
303 |
+
os.makedirs(output_dir, exist_ok=True)
|
304 |
+
|
305 |
+
# Take care of distributed/parallel training
|
306 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
307 |
+
model_to_save.save_pretrained(output_dir)
|
308 |
+
tokenizer.save_pretrained(output_dir)
|
309 |
+
torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
310 |
+
|
311 |
+
|
312 |
+
def print_trainable_parameters(model):
|
313 |
+
"""
|
314 |
+
Prints the number of trainable parameters in the model.
|
315 |
+
"""
|
316 |
+
trainable_params = 0
|
317 |
+
all_param = 0
|
318 |
+
for _, param in model.named_parameters():
|
319 |
+
all_param += param.numel()
|
320 |
+
if param.requires_grad:
|
321 |
+
trainable_params += param.numel()
|
322 |
+
print(
|
323 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
324 |
+
)
|
325 |
+
|
326 |
+
|
327 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
328 |
+
"""Find all linear layer names in the model. reference from qlora paper."""
|
329 |
+
cls = torch.nn.Linear
|
330 |
+
if int4 or int8:
|
331 |
+
import bitsandbytes as bnb
|
332 |
+
if int4:
|
333 |
+
cls = bnb.nn.Linear4bit
|
334 |
+
elif int8:
|
335 |
+
cls = bnb.nn.Linear8bitLt
|
336 |
+
lora_module_names = set()
|
337 |
+
for name, module in peft_model.named_modules():
|
338 |
+
if isinstance(module, cls):
|
339 |
+
# last layer is not add to lora_module_names
|
340 |
+
if 'lm_head' in name:
|
341 |
+
continue
|
342 |
+
names = name.split('.')
|
343 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
344 |
+
return sorted(lora_module_names)
|
345 |
+
|
346 |
+
|
347 |
+
def main():
|
348 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
|
349 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
350 |
+
|
351 |
+
logger.info(f"Model args: {model_args}")
|
352 |
+
logger.info(f"Data args: {data_args}")
|
353 |
+
logger.info(f"Training args: {training_args}")
|
354 |
+
logger.info(
|
355 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
356 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
357 |
+
)
|
358 |
+
|
359 |
+
# Set seed before initializing model.
|
360 |
+
set_seed(training_args.seed)
|
361 |
+
|
362 |
+
# Load tokenizer
|
363 |
+
if not model_args.model_type:
|
364 |
+
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
|
365 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
|
366 |
+
|
367 |
+
tokenizer_kwargs = {
|
368 |
+
"cache_dir": model_args.cache_dir,
|
369 |
+
"use_fast": model_args.use_fast_tokenizer,
|
370 |
+
"trust_remote_code": model_args.trust_remote_code,
|
371 |
+
}
|
372 |
+
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
373 |
+
if not tokenizer_name_or_path:
|
374 |
+
tokenizer_name_or_path = model_args.model_name_or_path
|
375 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
376 |
+
|
377 |
+
# Preprocessing the datasets.
|
378 |
+
def tokenize_function(examples):
|
379 |
+
return tokenizer(examples["text"])
|
380 |
+
|
381 |
+
if data_args.block_size is None:
|
382 |
+
block_size = tokenizer.model_max_length
|
383 |
+
if block_size > 2048:
|
384 |
+
logger.warning(
|
385 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
386 |
+
" of 2048. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
387 |
+
" override this default with `--block_size xxx`."
|
388 |
+
)
|
389 |
+
else:
|
390 |
+
if data_args.block_size > tokenizer.model_max_length:
|
391 |
+
logger.warning(
|
392 |
+
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
393 |
+
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
394 |
+
)
|
395 |
+
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
396 |
+
|
397 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
398 |
+
def group_texts(examples):
|
399 |
+
# Concatenate all texts.
|
400 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
401 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
402 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
403 |
+
# customize this part to your needs.
|
404 |
+
if total_length >= block_size:
|
405 |
+
total_length = (total_length // block_size) * block_size
|
406 |
+
# Split by chunks of max_len.
|
407 |
+
result = {
|
408 |
+
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
409 |
+
for k, t in concatenated_examples.items()
|
410 |
+
}
|
411 |
+
result["labels"] = result["input_ids"].copy()
|
412 |
+
return result
|
413 |
+
|
414 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
415 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
416 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
417 |
+
#
|
418 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
419 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
420 |
+
#
|
421 |
+
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
422 |
+
# download the dataset.
|
423 |
+
if data_args.dataset_name is not None:
|
424 |
+
# Downloading and loading a dataset from the hub.
|
425 |
+
raw_datasets = load_dataset(
|
426 |
+
data_args.dataset_name,
|
427 |
+
data_args.dataset_config_name,
|
428 |
+
cache_dir=model_args.cache_dir,
|
429 |
+
streaming=data_args.streaming,
|
430 |
+
)
|
431 |
+
if "validation" not in raw_datasets.keys():
|
432 |
+
raw_datasets["validation"] = load_dataset(
|
433 |
+
data_args.dataset_name,
|
434 |
+
data_args.dataset_config_name,
|
435 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
436 |
+
cache_dir=model_args.cache_dir,
|
437 |
+
streaming=data_args.streaming,
|
438 |
+
)
|
439 |
+
raw_datasets["train"] = load_dataset(
|
440 |
+
data_args.dataset_name,
|
441 |
+
data_args.dataset_config_name,
|
442 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
443 |
+
cache_dir=model_args.cache_dir,
|
444 |
+
streaming=data_args.streaming,
|
445 |
+
)
|
446 |
+
else:
|
447 |
+
data_files = {}
|
448 |
+
dataset_args = {}
|
449 |
+
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
|
450 |
+
train_data_files = glob(f'{data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
|
451 |
+
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
452 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
453 |
+
logger.info(f"train files: {train_data_files}")
|
454 |
+
# Train data files must be same type, e.g. all txt or all jsonl
|
455 |
+
types = [f.split('.')[-1] for f in train_data_files]
|
456 |
+
if len(set(types)) > 1:
|
457 |
+
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
|
458 |
+
data_files["train"] = train_data_files
|
459 |
+
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
|
460 |
+
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.txt', recursive=True) + glob(
|
461 |
+
f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
462 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
463 |
+
logger.info(f"eval files: {eval_data_files}")
|
464 |
+
data_files["validation"] = eval_data_files
|
465 |
+
# Train data files must be same type, e.g. all txt or all jsonl
|
466 |
+
types = [f.split('.')[-1] for f in eval_data_files]
|
467 |
+
if len(set(types)) > 1:
|
468 |
+
raise ValueError(f"train files must be same type, e.g. all txt or all jsonl, but got {types}")
|
469 |
+
extension = "text" if data_files["train"][0].endswith('txt') else 'json'
|
470 |
+
if extension == "text":
|
471 |
+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
472 |
+
raw_datasets = load_dataset(
|
473 |
+
extension,
|
474 |
+
data_files=data_files,
|
475 |
+
cache_dir=model_args.cache_dir,
|
476 |
+
**dataset_args,
|
477 |
+
)
|
478 |
+
|
479 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
480 |
+
if "validation" not in raw_datasets.keys():
|
481 |
+
raw_datasets["validation"] = load_dataset(
|
482 |
+
extension,
|
483 |
+
data_files=data_files,
|
484 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
485 |
+
cache_dir=model_args.cache_dir,
|
486 |
+
**dataset_args,
|
487 |
+
)
|
488 |
+
raw_datasets["train"] = load_dataset(
|
489 |
+
extension,
|
490 |
+
data_files=data_files,
|
491 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
492 |
+
cache_dir=model_args.cache_dir,
|
493 |
+
**dataset_args,
|
494 |
+
)
|
495 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
496 |
+
|
497 |
+
# Preprocessing the datasets.
|
498 |
+
if training_args.do_train:
|
499 |
+
column_names = list(raw_datasets["train"].features)
|
500 |
+
else:
|
501 |
+
column_names = list(raw_datasets["validation"].features)
|
502 |
+
|
503 |
+
with training_args.main_process_first(desc="Dataset tokenization and grouping"):
|
504 |
+
if not data_args.streaming:
|
505 |
+
tokenized_datasets = raw_datasets.map(
|
506 |
+
tokenize_function,
|
507 |
+
batched=True,
|
508 |
+
num_proc=data_args.preprocessing_num_workers,
|
509 |
+
remove_columns=column_names,
|
510 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
511 |
+
desc="Running tokenizer on dataset",
|
512 |
+
)
|
513 |
+
lm_datasets = tokenized_datasets.map(
|
514 |
+
group_texts,
|
515 |
+
batched=True,
|
516 |
+
num_proc=data_args.preprocessing_num_workers,
|
517 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
518 |
+
desc=f"Grouping texts in chunks of {block_size}",
|
519 |
+
)
|
520 |
+
else:
|
521 |
+
tokenized_datasets = raw_datasets.map(
|
522 |
+
tokenize_function,
|
523 |
+
batched=True,
|
524 |
+
remove_columns=column_names,
|
525 |
+
)
|
526 |
+
lm_datasets = tokenized_datasets.map(
|
527 |
+
group_texts,
|
528 |
+
batched=True,
|
529 |
+
)
|
530 |
+
|
531 |
+
train_dataset = None
|
532 |
+
max_train_samples = 0
|
533 |
+
if training_args.do_train:
|
534 |
+
if "train" not in tokenized_datasets:
|
535 |
+
raise ValueError("--do_train requires a train dataset")
|
536 |
+
train_dataset = lm_datasets['train']
|
537 |
+
max_train_samples = len(train_dataset)
|
538 |
+
if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
|
539 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
540 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
541 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
542 |
+
logger.debug("Tokenized training example:")
|
543 |
+
logger.debug(tokenizer.decode(train_dataset[0]['input_ids']))
|
544 |
+
|
545 |
+
eval_dataset = None
|
546 |
+
max_eval_samples = 0
|
547 |
+
if training_args.do_eval:
|
548 |
+
if "validation" not in tokenized_datasets:
|
549 |
+
raise ValueError("--do_eval requires a validation dataset")
|
550 |
+
eval_dataset = lm_datasets["validation"]
|
551 |
+
max_eval_samples = len(eval_dataset)
|
552 |
+
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
|
553 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
554 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
555 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
556 |
+
logger.debug("Tokenized eval example:")
|
557 |
+
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
|
558 |
+
|
559 |
+
# Load model
|
560 |
+
if model_args.model_type and model_args.model_name_or_path:
|
561 |
+
torch_dtype = (
|
562 |
+
model_args.torch_dtype
|
563 |
+
if model_args.torch_dtype in ["auto", None]
|
564 |
+
else getattr(torch, model_args.torch_dtype)
|
565 |
+
)
|
566 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
567 |
+
ddp = world_size != 1
|
568 |
+
if ddp:
|
569 |
+
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
570 |
+
|
571 |
+
config = config_class.from_pretrained(
|
572 |
+
model_args.model_name_or_path,
|
573 |
+
torch_dtype=torch_dtype,
|
574 |
+
trust_remote_code=model_args.trust_remote_code,
|
575 |
+
cache_dir=model_args.cache_dir
|
576 |
+
)
|
577 |
+
model = model_class.from_pretrained(
|
578 |
+
model_args.model_name_or_path,
|
579 |
+
config=config,
|
580 |
+
load_in_8bit=model_args.load_in_8bit,
|
581 |
+
device_map=model_args.device_map,
|
582 |
+
trust_remote_code=model_args.trust_remote_code,
|
583 |
+
)
|
584 |
+
else:
|
585 |
+
raise ValueError(f"Error, model_name_or_path is None, Continue PT must be loaded from a pre-trained model")
|
586 |
+
|
587 |
+
if training_args.use_peft:
|
588 |
+
if training_args.peft_path is not None:
|
589 |
+
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
|
590 |
+
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
|
591 |
+
else:
|
592 |
+
logger.info("Init new peft model")
|
593 |
+
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
|
594 |
+
if target_modules and 'all' in target_modules:
|
595 |
+
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
|
596 |
+
modules_to_save = training_args.modules_to_save
|
597 |
+
if modules_to_save is not None:
|
598 |
+
modules_to_save = modules_to_save.split(',')
|
599 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
600 |
+
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
|
601 |
+
peft_config = LoraConfig(
|
602 |
+
task_type=TaskType.CAUSAL_LM,
|
603 |
+
target_modules=target_modules,
|
604 |
+
inference_mode=False,
|
605 |
+
r=training_args.lora_rank,
|
606 |
+
lora_alpha=training_args.lora_alpha,
|
607 |
+
lora_dropout=training_args.lora_dropout,
|
608 |
+
modules_to_save=modules_to_save)
|
609 |
+
model = get_peft_model(model, peft_config)
|
610 |
+
if model_args.load_in_8bit:
|
611 |
+
model = prepare_model_for_int8_training(model)
|
612 |
+
model.print_trainable_parameters()
|
613 |
+
else:
|
614 |
+
logger.info("Full parameters training")
|
615 |
+
model = model.float()
|
616 |
+
print_trainable_parameters(model)
|
617 |
+
|
618 |
+
# Initialize our Trainer
|
619 |
+
if training_args.gradient_checkpointing:
|
620 |
+
model.gradient_checkpointing_enable()
|
621 |
+
model.config.use_cache = False
|
622 |
+
else:
|
623 |
+
model.config.use_cache = True
|
624 |
+
model.enable_input_require_grads()
|
625 |
+
if not ddp and torch.cuda.device_count() > 1:
|
626 |
+
# Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
627 |
+
model.is_parallelizable = True
|
628 |
+
model.model_parallel = True
|
629 |
+
|
630 |
+
trainer = SavePeftModelTrainer(
|
631 |
+
model=model,
|
632 |
+
args=training_args,
|
633 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
634 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
635 |
+
tokenizer=tokenizer,
|
636 |
+
data_collator=fault_tolerance_data_collator,
|
637 |
+
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
638 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
639 |
+
if training_args.do_eval and not is_torch_tpu_available()
|
640 |
+
else None,
|
641 |
+
)
|
642 |
+
|
643 |
+
# Training
|
644 |
+
if training_args.do_train:
|
645 |
+
logger.info("*** Train ***")
|
646 |
+
logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
|
647 |
+
checkpoint = None
|
648 |
+
if training_args.resume_from_checkpoint is not None:
|
649 |
+
checkpoint = training_args.resume_from_checkpoint
|
650 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
651 |
+
|
652 |
+
metrics = train_result.metrics
|
653 |
+
metrics["train_samples"] = max_train_samples
|
654 |
+
logger.debug(f"Training metrics: {metrics}")
|
655 |
+
trainer.log_metrics("train", metrics)
|
656 |
+
trainer.save_metrics("train", metrics)
|
657 |
+
trainer.save_state()
|
658 |
+
logger.info(f"Saving model checkpoint to {training_args.output_dir}")
|
659 |
+
save_model(training_args.output_dir, model, tokenizer, training_args)
|
660 |
+
|
661 |
+
# Evaluation
|
662 |
+
if training_args.do_eval and trainer.is_world_process_zero():
|
663 |
+
logger.info("*** Evaluate ***")
|
664 |
+
metrics = trainer.evaluate()
|
665 |
+
|
666 |
+
metrics["eval_samples"] = max_eval_samples
|
667 |
+
try:
|
668 |
+
perplexity = math.exp(metrics["eval_loss"])
|
669 |
+
except OverflowError:
|
670 |
+
perplexity = float("inf")
|
671 |
+
metrics["perplexity"] = perplexity
|
672 |
+
logger.debug(f"Eval metrics: {metrics}")
|
673 |
+
trainer.log_metrics("eval", metrics)
|
674 |
+
trainer.save_metrics("eval", metrics)
|
675 |
+
|
676 |
+
|
677 |
+
if __name__ == "__main__":
|
678 |
+
main()
|
MedicalGPT-main/requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
loguru
|
2 |
+
transformers>=4.30.1
|
3 |
+
sentencepiece
|
4 |
+
datasets
|
5 |
+
tqdm
|
6 |
+
tensorboard
|
7 |
+
tqdm>=4.47.0
|
8 |
+
peft>=0.5.0
|
9 |
+
accelerate>=0.20.3
|
10 |
+
trl>=0.6.0
|
MedicalGPT-main/reward_modeling.py
ADDED
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description:
|
5 |
+
"""
|
6 |
+
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from glob import glob
|
11 |
+
from typing import Any, List, Union, Optional, Dict
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from datasets import load_dataset
|
15 |
+
from loguru import logger
|
16 |
+
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
|
17 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
18 |
+
from torch.utils.data import Dataset
|
19 |
+
from transformers import (
|
20 |
+
AutoConfig,
|
21 |
+
PreTrainedTokenizerBase,
|
22 |
+
BloomForSequenceClassification,
|
23 |
+
LlamaForSequenceClassification,
|
24 |
+
LlamaTokenizer,
|
25 |
+
BloomTokenizerFast,
|
26 |
+
AlbertForSequenceClassification,
|
27 |
+
BertForSequenceClassification,
|
28 |
+
BertTokenizer,
|
29 |
+
AutoTokenizer,
|
30 |
+
RobertaForSequenceClassification,
|
31 |
+
AutoModelForSequenceClassification,
|
32 |
+
RobertaTokenizer,
|
33 |
+
HfArgumentParser,
|
34 |
+
Trainer,
|
35 |
+
TrainingArguments,
|
36 |
+
set_seed,
|
37 |
+
)
|
38 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
39 |
+
|
40 |
+
MODEL_CLASSES = {
|
41 |
+
"bert": (AutoConfig, BertForSequenceClassification, BertTokenizer),
|
42 |
+
"roberta": (AutoConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
43 |
+
"albert": (AutoConfig, AlbertForSequenceClassification, AutoTokenizer),
|
44 |
+
"bloom": (AutoConfig, BloomForSequenceClassification, BloomTokenizerFast),
|
45 |
+
"llama": (AutoConfig, LlamaForSequenceClassification, LlamaTokenizer),
|
46 |
+
"auto": (AutoConfig, AutoModelForSequenceClassification, AutoTokenizer),
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
@dataclass
|
51 |
+
class ModelArguments:
|
52 |
+
"""
|
53 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
54 |
+
"""
|
55 |
+
|
56 |
+
model_type: str = field(
|
57 |
+
default=None,
|
58 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
59 |
+
)
|
60 |
+
model_name_or_path: Optional[str] = field(
|
61 |
+
default=None,
|
62 |
+
metadata={
|
63 |
+
"help": (
|
64 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
65 |
+
)
|
66 |
+
},
|
67 |
+
)
|
68 |
+
tokenizer_name_or_path: Optional[str] = field(
|
69 |
+
default=None,
|
70 |
+
metadata={
|
71 |
+
"help": (
|
72 |
+
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
|
73 |
+
)
|
74 |
+
},
|
75 |
+
)
|
76 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
77 |
+
cache_dir: Optional[str] = field(
|
78 |
+
default=None,
|
79 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
80 |
+
)
|
81 |
+
use_fast_tokenizer: bool = field(
|
82 |
+
default=False,
|
83 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
84 |
+
)
|
85 |
+
torch_dtype: Optional[str] = field(
|
86 |
+
default=None,
|
87 |
+
metadata={
|
88 |
+
"help": (
|
89 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
90 |
+
"dtype will be automatically derived from the model's weights."
|
91 |
+
),
|
92 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
93 |
+
},
|
94 |
+
)
|
95 |
+
device_map: Optional[str] = field(
|
96 |
+
default="auto",
|
97 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
98 |
+
)
|
99 |
+
trust_remote_code: bool = field(
|
100 |
+
default=True,
|
101 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
102 |
+
)
|
103 |
+
|
104 |
+
def __post_init__(self):
|
105 |
+
if self.model_type is None:
|
106 |
+
raise ValueError(
|
107 |
+
"You must specify a valid model_type to run training. Available model types are " + ", ".join(
|
108 |
+
MODEL_CLASSES.keys()))
|
109 |
+
if self.model_name_or_path is None:
|
110 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
111 |
+
|
112 |
+
|
113 |
+
@dataclass
|
114 |
+
class DataTrainingArguments:
|
115 |
+
"""
|
116 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
117 |
+
"""
|
118 |
+
|
119 |
+
dataset_name: Optional[str] = field(
|
120 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
121 |
+
)
|
122 |
+
dataset_config_name: Optional[str] = field(
|
123 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
124 |
+
)
|
125 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
|
126 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
|
127 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
128 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
129 |
+
max_train_samples: Optional[int] = field(
|
130 |
+
default=None,
|
131 |
+
metadata={
|
132 |
+
"help": (
|
133 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
134 |
+
"value if set."
|
135 |
+
)
|
136 |
+
},
|
137 |
+
)
|
138 |
+
max_eval_samples: Optional[int] = field(
|
139 |
+
default=None,
|
140 |
+
metadata={
|
141 |
+
"help": (
|
142 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
143 |
+
"value if set."
|
144 |
+
)
|
145 |
+
},
|
146 |
+
)
|
147 |
+
overwrite_cache: bool = field(
|
148 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
149 |
+
)
|
150 |
+
validation_split_percentage: Optional[int] = field(
|
151 |
+
default=1,
|
152 |
+
metadata={
|
153 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
154 |
+
},
|
155 |
+
)
|
156 |
+
preprocessing_num_workers: Optional[int] = field(
|
157 |
+
default=4,
|
158 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
159 |
+
)
|
160 |
+
|
161 |
+
|
162 |
+
@dataclass
|
163 |
+
class PeftArguments(TrainingArguments):
|
164 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
165 |
+
target_modules: Optional[str] = field(default="all")
|
166 |
+
lora_rank: Optional[int] = field(default=8)
|
167 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
168 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
169 |
+
modules_to_save: Optional[str] = field(default=None)
|
170 |
+
peft_path: Optional[str] = field(default=None)
|
171 |
+
|
172 |
+
|
173 |
+
def compute_metrics(eval_preds):
|
174 |
+
preds, labels = eval_preds
|
175 |
+
# Here, predictions is rewards_chosen and rewards_rejected.
|
176 |
+
if isinstance(preds, torch.Tensor):
|
177 |
+
preds = preds.detach().cpu().numpy()
|
178 |
+
if isinstance(labels, torch.Tensor):
|
179 |
+
labels = labels.detach().cpu().numpy()
|
180 |
+
# MSE
|
181 |
+
mse = mean_squared_error(labels, preds)
|
182 |
+
# MAE
|
183 |
+
mae = mean_absolute_error(labels, preds)
|
184 |
+
|
185 |
+
return {"mse": mse, "mae": mae}
|
186 |
+
|
187 |
+
|
188 |
+
@dataclass
|
189 |
+
class RewardDataCollatorWithPadding:
|
190 |
+
"""We need to define a special data collator that batches the data in our chosen vs rejected format"""
|
191 |
+
tokenizer: PreTrainedTokenizerBase
|
192 |
+
padding: Union[bool, str] = True
|
193 |
+
max_length: Optional[int] = None
|
194 |
+
pad_to_multiple_of: Optional[int] = None
|
195 |
+
return_tensors: str = "pt"
|
196 |
+
|
197 |
+
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
198 |
+
features_chosen = []
|
199 |
+
features_rejected = []
|
200 |
+
for feature in features:
|
201 |
+
features_chosen.append(
|
202 |
+
{
|
203 |
+
"input_ids": feature["input_ids_chosen"],
|
204 |
+
"attention_mask": feature["attention_mask_chosen"],
|
205 |
+
}
|
206 |
+
)
|
207 |
+
features_rejected.append(
|
208 |
+
{
|
209 |
+
"input_ids": feature["input_ids_rejected"],
|
210 |
+
"attention_mask": feature["attention_mask_rejected"],
|
211 |
+
}
|
212 |
+
)
|
213 |
+
batch_chosen = self.tokenizer.pad(
|
214 |
+
features_chosen,
|
215 |
+
padding=self.padding,
|
216 |
+
max_length=self.max_length,
|
217 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
218 |
+
return_tensors=self.return_tensors,
|
219 |
+
)
|
220 |
+
batch_rejected = self.tokenizer.pad(
|
221 |
+
features_rejected,
|
222 |
+
padding=self.padding,
|
223 |
+
max_length=self.max_length,
|
224 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
225 |
+
return_tensors=self.return_tensors,
|
226 |
+
)
|
227 |
+
batch = {
|
228 |
+
"input_ids_chosen": batch_chosen["input_ids"],
|
229 |
+
"attention_mask_chosen": batch_chosen["attention_mask"],
|
230 |
+
"input_ids_rejected": batch_rejected["input_ids"],
|
231 |
+
"attention_mask_rejected": batch_rejected["attention_mask"],
|
232 |
+
"return_loss": True,
|
233 |
+
}
|
234 |
+
return batch
|
235 |
+
|
236 |
+
|
237 |
+
class RewardTrainer(Trainer):
|
238 |
+
"""
|
239 |
+
Trainer for reward models
|
240 |
+
Define how to compute the reward loss. Use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
|
241 |
+
"""
|
242 |
+
|
243 |
+
def compute_loss(self, model, inputs, return_outputs=False):
|
244 |
+
rewards_chosen = model(input_ids=inputs["input_ids_chosen"],
|
245 |
+
attention_mask=inputs["attention_mask_chosen"])[0]
|
246 |
+
rewards_rejected = model(input_ids=inputs["input_ids_rejected"],
|
247 |
+
attention_mask=inputs["attention_mask_rejected"])[0]
|
248 |
+
loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
249 |
+
if return_outputs:
|
250 |
+
return loss, {"rewards_chosen": rewards_chosen, "rewards_rejected": rewards_rejected}
|
251 |
+
return loss
|
252 |
+
|
253 |
+
def evaluate(
|
254 |
+
self,
|
255 |
+
eval_dataset: Optional[Dataset] = None,
|
256 |
+
ignore_keys: Optional[List[str]] = None,
|
257 |
+
metric_key_prefix: str = "eval",
|
258 |
+
) -> Dict[str, float]:
|
259 |
+
if eval_dataset is None:
|
260 |
+
eval_dataset = self.eval_dataset
|
261 |
+
return super().evaluate(eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
|
262 |
+
|
263 |
+
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
|
264 |
+
# Prepare inputs for chosen and rejected separately
|
265 |
+
device = model.device
|
266 |
+
|
267 |
+
inputs_chosen = {
|
268 |
+
"input_ids": inputs["input_ids_chosen"].to(device),
|
269 |
+
"attention_mask": inputs["attention_mask_chosen"].to(device),
|
270 |
+
}
|
271 |
+
outputs_chosen = model(**inputs_chosen)
|
272 |
+
rewards_chosen = outputs_chosen.logits.detach()
|
273 |
+
|
274 |
+
inputs_rejected = {
|
275 |
+
"input_ids": inputs["input_ids_rejected"].to(device),
|
276 |
+
"attention_mask": inputs["attention_mask_rejected"].to(device),
|
277 |
+
}
|
278 |
+
outputs_rejected = model(**inputs_rejected)
|
279 |
+
rewards_rejected = outputs_rejected.logits.detach()
|
280 |
+
|
281 |
+
# Keep the compute_loss method
|
282 |
+
loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
|
283 |
+
if prediction_loss_only:
|
284 |
+
return (loss, None, None)
|
285 |
+
|
286 |
+
return (loss, rewards_chosen, rewards_rejected)
|
287 |
+
|
288 |
+
def save_model(self, output_dir=None, _internal_call=False):
|
289 |
+
"""Save the LoRA model."""
|
290 |
+
os.makedirs(output_dir, exist_ok=True)
|
291 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
292 |
+
self.model.save_pretrained(output_dir)
|
293 |
+
|
294 |
+
|
295 |
+
def save_model(output_dir, model, tokenizer, args):
|
296 |
+
"""Save the model and the tokenizer."""
|
297 |
+
os.makedirs(output_dir, exist_ok=True)
|
298 |
+
|
299 |
+
# Take care of distributed/parallel training
|
300 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
301 |
+
model_to_save.save_pretrained(output_dir)
|
302 |
+
tokenizer.save_pretrained(output_dir)
|
303 |
+
torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
304 |
+
|
305 |
+
|
306 |
+
class CastOutputToFloat(torch.nn.Sequential):
|
307 |
+
"""Cast the output of the model to float"""
|
308 |
+
|
309 |
+
def forward(self, x):
|
310 |
+
return super().forward(x).to(torch.float32)
|
311 |
+
|
312 |
+
|
313 |
+
def print_trainable_parameters(model):
|
314 |
+
"""
|
315 |
+
Prints the number of trainable parameters in the model.
|
316 |
+
"""
|
317 |
+
trainable_params = 0
|
318 |
+
all_param = 0
|
319 |
+
for _, param in model.named_parameters():
|
320 |
+
all_param += param.numel()
|
321 |
+
if param.requires_grad:
|
322 |
+
trainable_params += param.numel()
|
323 |
+
print(
|
324 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
329 |
+
cls = torch.nn.Linear
|
330 |
+
if int4 or int8:
|
331 |
+
import bitsandbytes as bnb
|
332 |
+
if int4:
|
333 |
+
cls = bnb.nn.Linear4bit
|
334 |
+
elif int8:
|
335 |
+
cls = bnb.nn.Linear8bitLt
|
336 |
+
lora_module_names = set()
|
337 |
+
for name, module in peft_model.named_modules():
|
338 |
+
if isinstance(module, cls):
|
339 |
+
# last layer is not add to lora_module_names
|
340 |
+
if 'lm_head' in name:
|
341 |
+
continue
|
342 |
+
if 'score' in name:
|
343 |
+
continue
|
344 |
+
names = name.split('.')
|
345 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
346 |
+
return sorted(lora_module_names)
|
347 |
+
|
348 |
+
|
349 |
+
def main():
|
350 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
|
351 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
352 |
+
|
353 |
+
logger.info(f"Model args: {model_args}")
|
354 |
+
logger.info(f"Data args: {data_args}")
|
355 |
+
logger.info(f"Training args: {training_args}")
|
356 |
+
logger.info(
|
357 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
358 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
359 |
+
)
|
360 |
+
|
361 |
+
# Set seed before initializing model.
|
362 |
+
set_seed(training_args.seed)
|
363 |
+
|
364 |
+
# Load model
|
365 |
+
if not model_args.model_type:
|
366 |
+
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
|
367 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
|
368 |
+
if model_args.model_name_or_path:
|
369 |
+
torch_dtype = (
|
370 |
+
model_args.torch_dtype
|
371 |
+
if model_args.torch_dtype in ["auto", None]
|
372 |
+
else getattr(torch, model_args.torch_dtype)
|
373 |
+
)
|
374 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
375 |
+
if world_size > 1:
|
376 |
+
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
377 |
+
config = config_class.from_pretrained(
|
378 |
+
model_args.model_name_or_path,
|
379 |
+
num_labels=1,
|
380 |
+
torch_dtype=torch_dtype,
|
381 |
+
trust_remote_code=model_args.trust_remote_code,
|
382 |
+
cache_dir=model_args.cache_dir
|
383 |
+
)
|
384 |
+
if model_args.model_type in ['bloom', 'llama']:
|
385 |
+
model = model_class.from_pretrained(
|
386 |
+
model_args.model_name_or_path,
|
387 |
+
config=config,
|
388 |
+
load_in_8bit=model_args.load_in_8bit,
|
389 |
+
device_map=model_args.device_map,
|
390 |
+
trust_remote_code=model_args.trust_remote_code,
|
391 |
+
)
|
392 |
+
model.score = CastOutputToFloat(model.score)
|
393 |
+
else:
|
394 |
+
model = model_class.from_pretrained(
|
395 |
+
model_args.model_name_or_path,
|
396 |
+
config=config,
|
397 |
+
cache_dir=model_args.cache_dir,
|
398 |
+
ignore_mismatched_sizes=True
|
399 |
+
)
|
400 |
+
model.to(training_args.device)
|
401 |
+
else:
|
402 |
+
raise ValueError(f"Error, model_name_or_path is None, RM must be loaded from a pre-trained model")
|
403 |
+
|
404 |
+
# Load tokenizer
|
405 |
+
if model_args.model_type == "bloom":
|
406 |
+
model_args.use_fast_tokenizer = True
|
407 |
+
tokenizer_kwargs = {
|
408 |
+
"cache_dir": model_args.cache_dir,
|
409 |
+
"use_fast": model_args.use_fast_tokenizer,
|
410 |
+
"trust_remote_code": model_args.trust_remote_code,
|
411 |
+
}
|
412 |
+
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
413 |
+
if not tokenizer_name_or_path:
|
414 |
+
tokenizer_name_or_path = model_args.model_name_or_path
|
415 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
416 |
+
if tokenizer.pad_token_id is None:
|
417 |
+
tokenizer.pad_token_id = 0
|
418 |
+
|
419 |
+
if training_args.use_peft:
|
420 |
+
if training_args.peft_path is not None:
|
421 |
+
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
|
422 |
+
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
|
423 |
+
else:
|
424 |
+
logger.info("Init new peft model")
|
425 |
+
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
|
426 |
+
if target_modules and 'all' in target_modules:
|
427 |
+
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
|
428 |
+
modules_to_save = training_args.modules_to_save
|
429 |
+
if modules_to_save is not None:
|
430 |
+
modules_to_save = modules_to_save.split(',')
|
431 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
432 |
+
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
|
433 |
+
peft_config = LoraConfig(
|
434 |
+
task_type=TaskType.SEQ_CLS,
|
435 |
+
target_modules=target_modules,
|
436 |
+
inference_mode=False,
|
437 |
+
r=training_args.lora_rank,
|
438 |
+
lora_alpha=training_args.lora_alpha,
|
439 |
+
lora_dropout=training_args.lora_dropout,
|
440 |
+
modules_to_save=modules_to_save)
|
441 |
+
model = get_peft_model(model, peft_config)
|
442 |
+
if model_args.load_in_8bit:
|
443 |
+
model = prepare_model_for_int8_training(model)
|
444 |
+
model.print_trainable_parameters()
|
445 |
+
else:
|
446 |
+
logger.info("Full parameters training")
|
447 |
+
print_trainable_parameters(model)
|
448 |
+
|
449 |
+
# Get reward dataset for tuning the reward model.
|
450 |
+
if data_args.dataset_name is not None:
|
451 |
+
# Downloading and loading a dataset from the hub.
|
452 |
+
raw_datasets = load_dataset(
|
453 |
+
data_args.dataset_name,
|
454 |
+
data_args.dataset_config_name,
|
455 |
+
cache_dir=model_args.cache_dir,
|
456 |
+
)
|
457 |
+
if "validation" not in raw_datasets.keys():
|
458 |
+
raw_datasets["validation"] = load_dataset(
|
459 |
+
data_args.dataset_name,
|
460 |
+
data_args.dataset_config_name,
|
461 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
462 |
+
cache_dir=model_args.cache_dir,
|
463 |
+
)
|
464 |
+
raw_datasets["train"] = load_dataset(
|
465 |
+
data_args.dataset_name,
|
466 |
+
data_args.dataset_config_name,
|
467 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
468 |
+
cache_dir=model_args.cache_dir,
|
469 |
+
)
|
470 |
+
else:
|
471 |
+
data_files = {}
|
472 |
+
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
|
473 |
+
train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
474 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
475 |
+
logger.info(f"train files: {', '.join(train_data_files)}")
|
476 |
+
data_files["train"] = train_data_files
|
477 |
+
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
|
478 |
+
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
479 |
+
f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
|
480 |
+
logger.info(f"eval files: {', '.join(eval_data_files)}")
|
481 |
+
data_files["validation"] = eval_data_files
|
482 |
+
raw_datasets = load_dataset(
|
483 |
+
'json',
|
484 |
+
data_files=data_files,
|
485 |
+
cache_dir=model_args.cache_dir,
|
486 |
+
)
|
487 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
488 |
+
if "validation" not in raw_datasets.keys():
|
489 |
+
raw_datasets["validation"] = load_dataset(
|
490 |
+
'json',
|
491 |
+
data_files=data_files,
|
492 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
493 |
+
cache_dir=model_args.cache_dir,
|
494 |
+
)
|
495 |
+
raw_datasets["train"] = load_dataset(
|
496 |
+
'json',
|
497 |
+
data_files=data_files,
|
498 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
499 |
+
cache_dir=model_args.cache_dir,
|
500 |
+
)
|
501 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
502 |
+
|
503 |
+
# Preprocessing the datasets
|
504 |
+
full_max_length = data_args.max_source_length + data_args.max_target_length
|
505 |
+
|
506 |
+
def preprocess_reward_function(examples):
|
507 |
+
"""
|
508 |
+
Turn the dataset into pairs of Question + Answer, where input_ids_chosen is the preferred question + answer
|
509 |
+
and text_rejected is the other.
|
510 |
+
"""
|
511 |
+
new_examples = {
|
512 |
+
"input_ids_chosen": [],
|
513 |
+
"attention_mask_chosen": [],
|
514 |
+
"input_ids_rejected": [],
|
515 |
+
"attention_mask_rejected": [],
|
516 |
+
}
|
517 |
+
for question, chosen, rejected in zip(examples["question"], examples["response_chosen"],
|
518 |
+
examples["response_rejected"]):
|
519 |
+
tokenized_chosen = tokenizer("Question: " + question + "\n\nAnswer: " + chosen)
|
520 |
+
tokenized_rejected = tokenizer("Question: " + question + "\n\nAnswer: " + rejected)
|
521 |
+
|
522 |
+
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
|
523 |
+
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
|
524 |
+
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
|
525 |
+
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
|
526 |
+
|
527 |
+
return new_examples
|
528 |
+
|
529 |
+
train_dataset = None
|
530 |
+
max_train_samples = 0
|
531 |
+
if training_args.do_train:
|
532 |
+
if "train" not in raw_datasets:
|
533 |
+
raise ValueError("--do_train requires a train dataset")
|
534 |
+
train_dataset = raw_datasets['train']
|
535 |
+
max_train_samples = len(train_dataset)
|
536 |
+
if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
|
537 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
538 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
539 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
540 |
+
with training_args.main_process_first(desc="Train dataset tokenization"):
|
541 |
+
tokenized_dataset = train_dataset.shuffle().map(
|
542 |
+
preprocess_reward_function,
|
543 |
+
batched=True,
|
544 |
+
num_proc=data_args.preprocessing_num_workers,
|
545 |
+
remove_columns=train_dataset.column_names,
|
546 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
547 |
+
desc="Running tokenizer on dataset",
|
548 |
+
)
|
549 |
+
train_dataset = tokenized_dataset.filter(
|
550 |
+
lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
|
551 |
+
x['input_ids_chosen']) <= full_max_length
|
552 |
+
)
|
553 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
554 |
+
logger.debug("Tokenized training example:")
|
555 |
+
logger.debug(tokenizer.decode(train_dataset[0]['input_ids_chosen']))
|
556 |
+
|
557 |
+
eval_dataset = None
|
558 |
+
max_eval_samples = 0
|
559 |
+
if training_args.do_eval:
|
560 |
+
with training_args.main_process_first(desc="Eval dataset tokenization"):
|
561 |
+
if "validation" not in raw_datasets:
|
562 |
+
raise ValueError("--do_eval requires a validation dataset")
|
563 |
+
eval_dataset = raw_datasets["validation"]
|
564 |
+
max_eval_samples = len(eval_dataset)
|
565 |
+
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
|
566 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
567 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
568 |
+
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
|
569 |
+
tokenized_dataset = eval_dataset.map(
|
570 |
+
preprocess_reward_function,
|
571 |
+
batched=True,
|
572 |
+
num_proc=data_args.preprocessing_num_workers,
|
573 |
+
remove_columns=eval_dataset.column_names,
|
574 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
575 |
+
desc="Running tokenizer on dataset",
|
576 |
+
)
|
577 |
+
eval_dataset = tokenized_dataset.filter(
|
578 |
+
lambda x: 0 < len(x['input_ids_rejected']) <= full_max_length and 0 < len(
|
579 |
+
x['input_ids_chosen']) <= full_max_length
|
580 |
+
)
|
581 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
582 |
+
logger.debug("Tokenized eval example:")
|
583 |
+
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids_chosen']))
|
584 |
+
|
585 |
+
# Initialize our Trainer
|
586 |
+
if training_args.gradient_checkpointing:
|
587 |
+
model.gradient_checkpointing_enable()
|
588 |
+
model.config.use_cache = False
|
589 |
+
else:
|
590 |
+
model.config.use_cache = True
|
591 |
+
model.enable_input_require_grads()
|
592 |
+
if torch.cuda.device_count() > 1:
|
593 |
+
# Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
594 |
+
model.is_parallelizable = True
|
595 |
+
model.model_parallel = True
|
596 |
+
trainer = RewardTrainer(
|
597 |
+
model=model,
|
598 |
+
args=training_args,
|
599 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
600 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
601 |
+
tokenizer=tokenizer,
|
602 |
+
compute_metrics=compute_metrics,
|
603 |
+
data_collator=RewardDataCollatorWithPadding(
|
604 |
+
tokenizer=tokenizer, max_length=full_max_length, padding="max_length"
|
605 |
+
),
|
606 |
+
)
|
607 |
+
|
608 |
+
# Training
|
609 |
+
if training_args.do_train:
|
610 |
+
logger.info("*** Train ***")
|
611 |
+
logger.debug(f"Train dataloader example: {next(iter(trainer.get_train_dataloader()))}")
|
612 |
+
checkpoint = None
|
613 |
+
if training_args.resume_from_checkpoint is not None:
|
614 |
+
checkpoint = training_args.resume_from_checkpoint
|
615 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
616 |
+
|
617 |
+
metrics = train_result.metrics
|
618 |
+
metrics["train_samples"] = max_train_samples
|
619 |
+
logger.debug(f"Training metrics: {metrics}")
|
620 |
+
trainer.log_metrics("train", metrics)
|
621 |
+
trainer.save_metrics("train", metrics)
|
622 |
+
trainer.save_state()
|
623 |
+
logger.info(f"Saving model checkpoint to {training_args.output_dir}")
|
624 |
+
save_model(training_args.output_dir, model, tokenizer, training_args)
|
625 |
+
|
626 |
+
# Evaluation
|
627 |
+
if training_args.do_eval and trainer.is_world_process_zero():
|
628 |
+
logger.info("*** Evaluate ***")
|
629 |
+
metrics = trainer.evaluate()
|
630 |
+
|
631 |
+
metrics["eval_samples"] = max_eval_samples
|
632 |
+
try:
|
633 |
+
perplexity = math.exp(metrics["eval_loss"])
|
634 |
+
except OverflowError:
|
635 |
+
perplexity = float("inf")
|
636 |
+
metrics["perplexity"] = perplexity
|
637 |
+
logger.debug(f"Eval metrics: {metrics}")
|
638 |
+
trainer.log_metrics("eval", metrics)
|
639 |
+
trainer.save_metrics("eval", metrics)
|
640 |
+
|
641 |
+
|
642 |
+
if __name__ == "__main__":
|
643 |
+
main()
|
MedicalGPT-main/rl_training.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
@author:XuMing(xuming624@qq.com)
|
4 |
+
@description: Train a model from SFT using PPO
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from glob import glob
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from datasets import load_dataset
|
14 |
+
from loguru import logger
|
15 |
+
from peft import LoraConfig, TaskType
|
16 |
+
from tqdm import tqdm
|
17 |
+
from transformers import (
|
18 |
+
AutoConfig,
|
19 |
+
AutoModelForSequenceClassification,
|
20 |
+
BloomForCausalLM,
|
21 |
+
AutoModelForCausalLM,
|
22 |
+
AutoModel,
|
23 |
+
LlamaTokenizer,
|
24 |
+
LlamaForCausalLM,
|
25 |
+
BloomTokenizerFast,
|
26 |
+
AutoTokenizer,
|
27 |
+
HfArgumentParser,
|
28 |
+
)
|
29 |
+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
30 |
+
|
31 |
+
from supervised_finetuning import get_conv_template
|
32 |
+
|
33 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
|
34 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
35 |
+
|
36 |
+
MODEL_CLASSES = {
|
37 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
38 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
39 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
40 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
41 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
42 |
+
}
|
43 |
+
|
44 |
+
|
45 |
+
@dataclass
|
46 |
+
class ScriptArguments:
|
47 |
+
"""
|
48 |
+
The name of the Casual LM model we wish to fine with PPO
|
49 |
+
"""
|
50 |
+
# Model arguments
|
51 |
+
model_type: str = field(
|
52 |
+
default=None,
|
53 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
54 |
+
)
|
55 |
+
model_name_or_path: Optional[str] = field(
|
56 |
+
default=None, metadata={"help": "The model checkpoint for weights initialization."}
|
57 |
+
)
|
58 |
+
reward_model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The reward model name"})
|
59 |
+
tokenizer_name_or_path: Optional[str] = field(
|
60 |
+
default=None, metadata={"help": "The tokenizer for weights initialization."}
|
61 |
+
)
|
62 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
63 |
+
cache_dir: Optional[str] = field(
|
64 |
+
default=None,
|
65 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
66 |
+
)
|
67 |
+
use_fast_tokenizer: bool = field(
|
68 |
+
default=False,
|
69 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
70 |
+
)
|
71 |
+
torch_dtype: Optional[str] = field(
|
72 |
+
default=None,
|
73 |
+
metadata={
|
74 |
+
"help": (
|
75 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
76 |
+
"dtype will be automatically derived from the model's weights."
|
77 |
+
),
|
78 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
79 |
+
},
|
80 |
+
)
|
81 |
+
device_map: Optional[str] = field(
|
82 |
+
default="auto",
|
83 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
84 |
+
)
|
85 |
+
trust_remote_code: bool = field(
|
86 |
+
default=True,
|
87 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
88 |
+
)
|
89 |
+
# Dataset arguments
|
90 |
+
dataset_name: Optional[str] = field(
|
91 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
92 |
+
)
|
93 |
+
dataset_config_name: Optional[str] = field(
|
94 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
95 |
+
)
|
96 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The input jsonl data file folder."})
|
97 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."}, )
|
98 |
+
template_name: Optional[str] = field(default="vicuna", metadata={"help": "The template name."})
|
99 |
+
batch_size: Optional[int] = field(default=8, metadata={"help": "Batch size"})
|
100 |
+
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "PPO minibatch size"})
|
101 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
102 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
103 |
+
min_target_length: Optional[int] = field(default=4, metadata={"help": "Min length of output text"})
|
104 |
+
max_train_samples: Optional[int] = field(
|
105 |
+
default=None,
|
106 |
+
metadata={
|
107 |
+
"help": (
|
108 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
109 |
+
"value if set."
|
110 |
+
)
|
111 |
+
},
|
112 |
+
)
|
113 |
+
max_eval_samples: Optional[int] = field(
|
114 |
+
default=None,
|
115 |
+
metadata={
|
116 |
+
"help": (
|
117 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
118 |
+
"value if set."
|
119 |
+
)
|
120 |
+
},
|
121 |
+
)
|
122 |
+
overwrite_cache: bool = field(
|
123 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
124 |
+
)
|
125 |
+
validation_split_percentage: Optional[int] = field(
|
126 |
+
default=1,
|
127 |
+
metadata={
|
128 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
129 |
+
},
|
130 |
+
)
|
131 |
+
preprocessing_num_workers: Optional[int] = field(
|
132 |
+
default=None, metadata={"help": "The number of processes to use for the preprocessing."},
|
133 |
+
)
|
134 |
+
# Training arguments
|
135 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
136 |
+
target_modules: Optional[str] = field(default=None)
|
137 |
+
lora_rank: Optional[int] = field(default=8)
|
138 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
139 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
140 |
+
modules_to_save: Optional[str] = field(default=None)
|
141 |
+
peft_path: Optional[str] = field(default=None)
|
142 |
+
|
143 |
+
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
|
144 |
+
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the validation set."})
|
145 |
+
early_stopping: Optional[bool] = field(default=False, metadata={"help": "Whether to early stop"})
|
146 |
+
target_kl: Optional[float] = field(default=0.1, metadata={"help": "The kl target for early stopping"})
|
147 |
+
reward_baseline: Optional[float] = field(
|
148 |
+
default=0.0, metadata={"help": "Baseline value that is subtracted from the reward"},
|
149 |
+
)
|
150 |
+
init_kl_coef: Optional[float] = field(
|
151 |
+
default=0.2, metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
|
152 |
+
)
|
153 |
+
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
154 |
+
learning_rate: Optional[float] = field(default=1.5e-5, metadata={"help": "Learning rate"})
|
155 |
+
gradient_accumulation_steps: Optional[int] = field(
|
156 |
+
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
157 |
+
)
|
158 |
+
save_steps: Optional[int] = field(default=50, metadata={"help": "X steps to save the model"})
|
159 |
+
output_dir: Optional[str] = field(default="outputs-rl", metadata={"help": "The output directory"})
|
160 |
+
seed: Optional[int] = field(default=0, metadata={"help": "Seed"})
|
161 |
+
max_steps: Optional[int] = field(default=200, metadata={"help": "Number of steps to train"})
|
162 |
+
report_to: Optional[str] = field(default="tensorboard", metadata={"help": "Report to wandb or tensorboard"})
|
163 |
+
|
164 |
+
def __post_init__(self):
|
165 |
+
if self.model_type is None:
|
166 |
+
raise ValueError("You must specify a valid model_type to run training.")
|
167 |
+
if self.model_name_or_path is None:
|
168 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
169 |
+
if self.reward_model_name_or_path is None:
|
170 |
+
raise ValueError("You must specify a valid reward_model_name_or_path to run training.")
|
171 |
+
|
172 |
+
|
173 |
+
def print_trainable_parameters(model):
|
174 |
+
"""
|
175 |
+
Prints the number of trainable parameters in the model.
|
176 |
+
"""
|
177 |
+
trainable_params = 0
|
178 |
+
all_param = 0
|
179 |
+
for _, param in model.named_parameters():
|
180 |
+
all_param += param.numel()
|
181 |
+
if param.requires_grad:
|
182 |
+
trainable_params += param.numel()
|
183 |
+
print(
|
184 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
def get_reward_model_output(reward_model, reward_tokenizer, question, answer, device):
|
189 |
+
"""
|
190 |
+
Get the reward score for a given question and answer pair.
|
191 |
+
"""
|
192 |
+
inputs = reward_tokenizer(question, answer, return_tensors='pt').to(device)
|
193 |
+
score = reward_model(**inputs).logits[0].cpu().detach()
|
194 |
+
|
195 |
+
return score
|
196 |
+
|
197 |
+
|
198 |
+
def calculate_rewards(reward_score_outputs, reward_baseline=0):
|
199 |
+
"""
|
200 |
+
Calculate the reward for a given score output.
|
201 |
+
:param reward_score_outputs:
|
202 |
+
:param reward_baseline:
|
203 |
+
:return:
|
204 |
+
"""
|
205 |
+
rewards = []
|
206 |
+
for score in reward_score_outputs:
|
207 |
+
if isinstance(score, torch.Tensor) and score.numel() == 1:
|
208 |
+
reward_value = score.item() - reward_baseline
|
209 |
+
rewards.append(torch.tensor(reward_value))
|
210 |
+
else:
|
211 |
+
# Use the average of the tensor elements as `score` is multiple elements
|
212 |
+
reward_value = torch.mean(score).item() - reward_baseline
|
213 |
+
rewards.append(torch.tensor(reward_value))
|
214 |
+
return rewards
|
215 |
+
|
216 |
+
|
217 |
+
def main():
|
218 |
+
parser = HfArgumentParser(ScriptArguments)
|
219 |
+
args = parser.parse_args_into_dataclasses()[0]
|
220 |
+
logger.info(f"Parse args: {args}")
|
221 |
+
|
222 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
223 |
+
if args.model_type == 'bloom':
|
224 |
+
args.use_fast_tokenizer = True
|
225 |
+
# Load tokenizer
|
226 |
+
tokenizer_kwargs = {
|
227 |
+
"cache_dir": args.cache_dir,
|
228 |
+
"use_fast": args.use_fast_tokenizer,
|
229 |
+
"trust_remote_code": args.trust_remote_code,
|
230 |
+
}
|
231 |
+
tokenizer_name_or_path = args.tokenizer_name_or_path
|
232 |
+
if not tokenizer_name_or_path:
|
233 |
+
tokenizer_name_or_path = args.model_name_or_path
|
234 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
235 |
+
if tokenizer.pad_token_id is None:
|
236 |
+
tokenizer.pad_token_id = 0 # set as the <unk> token
|
237 |
+
|
238 |
+
logger.info("Load model")
|
239 |
+
peft_config = LoraConfig(
|
240 |
+
task_type=TaskType.CAUSAL_LM,
|
241 |
+
target_modules=args.target_modules,
|
242 |
+
inference_mode=False,
|
243 |
+
r=args.lora_rank,
|
244 |
+
lora_alpha=args.lora_alpha,
|
245 |
+
lora_dropout=args.lora_dropout,
|
246 |
+
)
|
247 |
+
torch_dtype = (
|
248 |
+
args.torch_dtype
|
249 |
+
if args.torch_dtype in ["auto", None]
|
250 |
+
else getattr(torch, args.torch_dtype)
|
251 |
+
)
|
252 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
253 |
+
if world_size > 1:
|
254 |
+
args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
255 |
+
config = config_class.from_pretrained(
|
256 |
+
args.model_name_or_path,
|
257 |
+
torch_dtype=torch_dtype,
|
258 |
+
trust_remote_code=args.trust_remote_code,
|
259 |
+
cache_dir=args.cache_dir
|
260 |
+
)
|
261 |
+
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
262 |
+
args.model_name_or_path,
|
263 |
+
config=config,
|
264 |
+
load_in_8bit=args.load_in_8bit,
|
265 |
+
device_map=args.device_map,
|
266 |
+
trust_remote_code=args.trust_remote_code,
|
267 |
+
peft_config=peft_config if args.use_peft else None,
|
268 |
+
)
|
269 |
+
print_trainable_parameters(model)
|
270 |
+
# Load reward model
|
271 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
272 |
+
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
273 |
+
args.reward_model_name_or_path,
|
274 |
+
config=config,
|
275 |
+
load_in_8bit=args.load_in_8bit,
|
276 |
+
trust_remote_code=args.trust_remote_code,
|
277 |
+
)
|
278 |
+
reward_model.to(device)
|
279 |
+
reward_tokenizer = AutoTokenizer.from_pretrained(
|
280 |
+
args.reward_model_name_or_path, **tokenizer_kwargs
|
281 |
+
)
|
282 |
+
|
283 |
+
# Get datasets
|
284 |
+
if args.dataset_name is not None:
|
285 |
+
# Downloading and loading a dataset from the hub.
|
286 |
+
raw_datasets = load_dataset(
|
287 |
+
args.dataset_name,
|
288 |
+
args.dataset_config_name,
|
289 |
+
cache_dir=args.cache_dir,
|
290 |
+
)
|
291 |
+
if "validation" not in raw_datasets.keys():
|
292 |
+
raw_datasets["validation"] = load_dataset(
|
293 |
+
args.dataset_name,
|
294 |
+
args.dataset_config_name,
|
295 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
296 |
+
cache_dir=args.cache_dir,
|
297 |
+
)
|
298 |
+
raw_datasets["train"] = load_dataset(
|
299 |
+
args.dataset_name,
|
300 |
+
args.dataset_config_name,
|
301 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
302 |
+
cache_dir=args.cache_dir,
|
303 |
+
)
|
304 |
+
else:
|
305 |
+
data_files = {}
|
306 |
+
if args.train_file_dir is not None and os.path.exists(args.train_file_dir):
|
307 |
+
train_data_files = glob(f'{args.train_file_dir}/**/*.json', recursive=True) + glob(
|
308 |
+
f'{args.train_file_dir}/**/*.jsonl', recursive=True)
|
309 |
+
logger.info(f"train files: {', '.join(train_data_files)}")
|
310 |
+
data_files["train"] = train_data_files
|
311 |
+
if args.validation_file_dir is not None and os.path.exists(args.validation_file_dir):
|
312 |
+
eval_data_files = glob(f'{args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
313 |
+
f'{args.validation_file_dir}/**/*.jsonl', recursive=True)
|
314 |
+
logger.info(f"eval files: {', '.join(eval_data_files)}")
|
315 |
+
data_files["validation"] = eval_data_files
|
316 |
+
raw_datasets = load_dataset(
|
317 |
+
'json',
|
318 |
+
data_files=data_files,
|
319 |
+
cache_dir=args.cache_dir,
|
320 |
+
)
|
321 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
322 |
+
if "validation" not in raw_datasets.keys():
|
323 |
+
raw_datasets["validation"] = load_dataset(
|
324 |
+
'json',
|
325 |
+
data_files=data_files,
|
326 |
+
split=f"train[:{args.validation_split_percentage}%]",
|
327 |
+
cache_dir=args.cache_dir,
|
328 |
+
)
|
329 |
+
raw_datasets["train"] = load_dataset(
|
330 |
+
'json',
|
331 |
+
data_files=data_files,
|
332 |
+
split=f"train[{args.validation_split_percentage}%:]",
|
333 |
+
cache_dir=args.cache_dir,
|
334 |
+
)
|
335 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
336 |
+
|
337 |
+
# Preprocessing the datasets
|
338 |
+
max_source_length = args.max_source_length
|
339 |
+
max_target_length = args.max_target_length
|
340 |
+
prompt_template = get_conv_template(args.template_name)
|
341 |
+
|
342 |
+
def preprocess_function(examples):
|
343 |
+
new_examples = {
|
344 |
+
"query": [],
|
345 |
+
"input_ids": [],
|
346 |
+
}
|
347 |
+
roles = ["human", "gpt"]
|
348 |
+
|
349 |
+
def get_prompt(examples):
|
350 |
+
for i, source in enumerate(examples['conversations']):
|
351 |
+
if len(source) < 2:
|
352 |
+
continue
|
353 |
+
data_role = source[0].get("from", "")
|
354 |
+
if data_role not in roles or data_role != roles[0]:
|
355 |
+
# Skip the first one if it is not from human
|
356 |
+
source = source[1:]
|
357 |
+
if len(source) < 2:
|
358 |
+
continue
|
359 |
+
messages = []
|
360 |
+
for j, sentence in enumerate(source):
|
361 |
+
data_role = sentence.get("from", "")
|
362 |
+
if data_role not in roles:
|
363 |
+
logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
|
364 |
+
break
|
365 |
+
if data_role == roles[j % 2]:
|
366 |
+
messages.append(sentence["value"])
|
367 |
+
if len(messages) < 2 or len(messages) % 2 != 0:
|
368 |
+
continue
|
369 |
+
# Convert the list to pairs of elements
|
370 |
+
history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
|
371 |
+
yield prompt_template.get_prompt(history_messages)
|
372 |
+
|
373 |
+
for prompt in get_prompt(examples):
|
374 |
+
for i in range(len(prompt) // 2):
|
375 |
+
source_txt = prompt[2 * i]
|
376 |
+
tokenized_question = tokenizer(
|
377 |
+
source_txt, truncation=True, max_length=max_source_length, padding="max_length",
|
378 |
+
return_tensors="pt"
|
379 |
+
)
|
380 |
+
new_examples["query"].append(source_txt)
|
381 |
+
new_examples["input_ids"].append(tokenized_question["input_ids"])
|
382 |
+
|
383 |
+
return new_examples
|
384 |
+
|
385 |
+
# Preprocess the dataset
|
386 |
+
train_dataset = None
|
387 |
+
if args.do_train:
|
388 |
+
if "train" not in raw_datasets:
|
389 |
+
raise ValueError("--do_train requires a train dataset")
|
390 |
+
train_dataset = raw_datasets['train']
|
391 |
+
if args.max_train_samples is not None and args.max_train_samples > 0:
|
392 |
+
max_train_samples = min(len(train_dataset), args.max_train_samples)
|
393 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
394 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
395 |
+
tokenized_dataset = train_dataset.shuffle().map(
|
396 |
+
preprocess_function,
|
397 |
+
batched=True,
|
398 |
+
num_proc=args.preprocessing_num_workers,
|
399 |
+
remove_columns=train_dataset.column_names,
|
400 |
+
load_from_cache_file=not args.overwrite_cache,
|
401 |
+
desc="Running tokenizer on dataset",
|
402 |
+
)
|
403 |
+
train_dataset = tokenized_dataset.filter(
|
404 |
+
lambda x: len(x['input_ids']) > 0
|
405 |
+
)
|
406 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
407 |
+
|
408 |
+
def collator(data):
|
409 |
+
return dict((key, [d[key] for d in data]) for key in data[0])
|
410 |
+
|
411 |
+
output_dir = args.output_dir
|
412 |
+
config = PPOConfig(
|
413 |
+
steps=args.max_steps,
|
414 |
+
model_name=args.model_name_or_path,
|
415 |
+
learning_rate=args.learning_rate,
|
416 |
+
log_with=args.report_to,
|
417 |
+
batch_size=args.batch_size,
|
418 |
+
mini_batch_size=args.mini_batch_size,
|
419 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
420 |
+
optimize_cuda_cache=True,
|
421 |
+
early_stopping=args.early_stopping,
|
422 |
+
target_kl=args.target_kl,
|
423 |
+
seed=args.seed,
|
424 |
+
init_kl_coef=args.init_kl_coef,
|
425 |
+
adap_kl_ctrl=args.adap_kl_ctrl,
|
426 |
+
project_kwargs={"logging_dir": output_dir},
|
427 |
+
)
|
428 |
+
# Set seed before initializing value head for deterministic eval
|
429 |
+
set_seed(config.seed)
|
430 |
+
|
431 |
+
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
432 |
+
trainer = PPOTrainer(
|
433 |
+
config,
|
434 |
+
model,
|
435 |
+
ref_model=None,
|
436 |
+
tokenizer=tokenizer,
|
437 |
+
dataset=train_dataset,
|
438 |
+
data_collator=collator,
|
439 |
+
)
|
440 |
+
|
441 |
+
# These arguments are passed to the `generate` function of the PPOTrainer
|
442 |
+
generation_kwargs = {
|
443 |
+
"max_new_tokens": max_target_length,
|
444 |
+
"temperature": 1.0,
|
445 |
+
"repetition_penalty": 1.0,
|
446 |
+
"top_p": 1.0,
|
447 |
+
"do_sample": True,
|
448 |
+
}
|
449 |
+
|
450 |
+
def save_model(save_dir):
|
451 |
+
trainer.accelerator.unwrap_model(trainer.model).save_pretrained(save_dir)
|
452 |
+
trainer.tokenizer.save_pretrained(save_dir)
|
453 |
+
|
454 |
+
# Training
|
455 |
+
if args.do_train:
|
456 |
+
logger.info("*** Train ***")
|
457 |
+
total_steps = config.total_ppo_epochs
|
458 |
+
for step, batch in tqdm(enumerate(trainer.dataloader)):
|
459 |
+
if step >= total_steps:
|
460 |
+
break
|
461 |
+
question_tensors = batch["input_ids"]
|
462 |
+
question_tensors = [torch.LongTensor(i).to(device).squeeze(0) for i in question_tensors]
|
463 |
+
responses = []
|
464 |
+
response_tensors = []
|
465 |
+
for q_tensor in question_tensors:
|
466 |
+
response_tensor = trainer.generate(
|
467 |
+
q_tensor,
|
468 |
+
return_prompt=False,
|
469 |
+
**generation_kwargs,
|
470 |
+
)
|
471 |
+
r = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)[0]
|
472 |
+
responses.append(r)
|
473 |
+
response_tensors.append(response_tensor.squeeze(0))
|
474 |
+
batch["response"] = responses
|
475 |
+
|
476 |
+
# Compute reward score
|
477 |
+
score_outputs = [
|
478 |
+
get_reward_model_output(reward_model, reward_tokenizer, q, r, device) for q, r in
|
479 |
+
zip(batch["query"], batch["response"])
|
480 |
+
]
|
481 |
+
rewards = calculate_rewards(score_outputs, args.reward_baseline)
|
482 |
+
|
483 |
+
# Run PPO step
|
484 |
+
try:
|
485 |
+
stats = trainer.step(question_tensors, response_tensors, rewards)
|
486 |
+
trainer.log_stats(stats, batch, rewards)
|
487 |
+
logger.debug(f"Step {step}/{total_steps}: reward score:{score_outputs}")
|
488 |
+
except ValueError as e:
|
489 |
+
logger.warning(f"Failed to log stats for step {step}, because of {e}")
|
490 |
+
|
491 |
+
if step and step % args.save_steps == 0:
|
492 |
+
save_dir = os.path.join(output_dir, f"checkpoint-{step}")
|
493 |
+
save_model(save_dir)
|
494 |
+
# Save final model
|
495 |
+
save_model(output_dir)
|
496 |
+
|
497 |
+
|
498 |
+
if __name__ == "__main__":
|
499 |
+
main()
|
MedicalGPT-main/run_dpo.sh
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 python dpo_training.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/reward \
|
5 |
+
--validation_file_dir ./data/reward \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 1 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--use_peft True \
|
11 |
+
--max_train_samples 1000 \
|
12 |
+
--max_eval_samples 10 \
|
13 |
+
--max_steps 100 \
|
14 |
+
--eval_steps 20 \
|
15 |
+
--save_steps 50 \
|
16 |
+
--max_source_length 128 \
|
17 |
+
--max_target_length 128 \
|
18 |
+
--output_dir outputs-dpo-bloom-v1 \
|
19 |
+
--target_modules all \
|
20 |
+
--lora_rank 8 \
|
21 |
+
--lora_alpha 16 \
|
22 |
+
--lora_dropout 0.05 \
|
23 |
+
--torch_dtype float16 \
|
24 |
+
--fp16 True \
|
25 |
+
--device_map auto \
|
26 |
+
--report_to tensorboard \
|
27 |
+
--remove_unused_columns False \
|
28 |
+
--gradient_checkpointing True \
|
29 |
+
--cache_dir ./cache
|
MedicalGPT-main/run_pt.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 pretraining.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/pretrain \
|
5 |
+
--validation_file_dir ./data/pretrain \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 4 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--use_peft True \
|
11 |
+
--seed 42 \
|
12 |
+
--fp16 \
|
13 |
+
--max_train_samples 10000 \
|
14 |
+
--max_eval_samples 10 \
|
15 |
+
--num_train_epochs 0.5 \
|
16 |
+
--learning_rate 2e-4 \
|
17 |
+
--warmup_ratio 0.05 \
|
18 |
+
--weight_decay 0.01 \
|
19 |
+
--logging_strategy steps \
|
20 |
+
--logging_steps 10 \
|
21 |
+
--eval_steps 50 \
|
22 |
+
--evaluation_strategy steps \
|
23 |
+
--save_steps 500 \
|
24 |
+
--save_strategy steps \
|
25 |
+
--save_total_limit 3 \
|
26 |
+
--gradient_accumulation_steps 1 \
|
27 |
+
--preprocessing_num_workers 1 \
|
28 |
+
--block_size 1024 \
|
29 |
+
--output_dir outputs-pt-bloom-v1 \
|
30 |
+
--overwrite_output_dir \
|
31 |
+
--ddp_timeout 30000 \
|
32 |
+
--logging_first_step True \
|
33 |
+
--target_modules all \
|
34 |
+
--lora_rank 8 \
|
35 |
+
--lora_alpha 16 \
|
36 |
+
--lora_dropout 0.05 \
|
37 |
+
--torch_dtype float16 \
|
38 |
+
--device_map auto \
|
39 |
+
--report_to tensorboard \
|
40 |
+
--ddp_find_unused_parameters False \
|
41 |
+
--gradient_checkpointing True \
|
42 |
+
--cache_dir ./cache
|
MedicalGPT-main/run_rl.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 rl_training.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--reward_model_name_or_path OpenAssistant/reward-model-deberta-v3-large-v2 \
|
5 |
+
--torch_dtype float16 \
|
6 |
+
--device_map auto \
|
7 |
+
--train_file_dir ./data/finetune \
|
8 |
+
--validation_file_dir ./data/finetune \
|
9 |
+
--batch_size 8 \
|
10 |
+
--max_source_length 256 \
|
11 |
+
--max_target_length 256 \
|
12 |
+
--max_train_samples 1000 \
|
13 |
+
--use_peft True \
|
14 |
+
--lora_rank 8 \
|
15 |
+
--lora_alpha 16 \
|
16 |
+
--lora_dropout 0.05 \
|
17 |
+
--do_train \
|
18 |
+
--max_steps 100 \
|
19 |
+
--learning_rate 1e-5 \
|
20 |
+
--save_steps 50 \
|
21 |
+
--output_dir outputs-rl-bloom-v1 \
|
22 |
+
--early_stopping True \
|
23 |
+
--target_kl 0.1 \
|
24 |
+
--reward_baseline 0.0
|
MedicalGPT-main/run_rm.sh
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 python reward_modeling.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/reward \
|
5 |
+
--validation_file_dir ./data/reward \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 4 \
|
8 |
+
--do_train \
|
9 |
+
--use_peft True \
|
10 |
+
--seed 42 \
|
11 |
+
--max_train_samples 1000 \
|
12 |
+
--max_eval_samples 10 \
|
13 |
+
--num_train_epochs 1 \
|
14 |
+
--learning_rate 2e-5 \
|
15 |
+
--warmup_ratio 0.05 \
|
16 |
+
--weight_decay 0.001 \
|
17 |
+
--logging_strategy steps \
|
18 |
+
--logging_steps 10 \
|
19 |
+
--eval_steps 50 \
|
20 |
+
--evaluation_strategy steps \
|
21 |
+
--save_steps 500 \
|
22 |
+
--save_strategy steps \
|
23 |
+
--save_total_limit 3 \
|
24 |
+
--max_source_length 256 \
|
25 |
+
--max_target_length 256 \
|
26 |
+
--output_dir outputs-rm-bloom-v1 \
|
27 |
+
--overwrite_output_dir \
|
28 |
+
--ddp_timeout 30000 \
|
29 |
+
--logging_first_step True \
|
30 |
+
--target_modules all \
|
31 |
+
--lora_rank 8 \
|
32 |
+
--lora_alpha 16 \
|
33 |
+
--lora_dropout 0.05 \
|
34 |
+
--torch_dtype float32 \
|
35 |
+
--device_map auto \
|
36 |
+
--report_to tensorboard \
|
37 |
+
--ddp_find_unused_parameters False \
|
38 |
+
--remove_unused_columns False \
|
39 |
+
--gradient_checkpointing True
|
MedicalGPT-main/run_sft.sh
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 supervised_finetuning.py \
|
2 |
+
--model_type bloom \
|
3 |
+
--model_name_or_path bigscience/bloomz-560m \
|
4 |
+
--train_file_dir ./data/finetune \
|
5 |
+
--validation_file_dir ./data/finetune \
|
6 |
+
--per_device_train_batch_size 4 \
|
7 |
+
--per_device_eval_batch_size 4 \
|
8 |
+
--do_train \
|
9 |
+
--do_eval \
|
10 |
+
--use_peft True \
|
11 |
+
--fp16 \
|
12 |
+
--max_train_samples 1000 \
|
13 |
+
--max_eval_samples 10 \
|
14 |
+
--num_train_epochs 1 \
|
15 |
+
--learning_rate 2e-5 \
|
16 |
+
--warmup_ratio 0.05 \
|
17 |
+
--weight_decay 0.05 \
|
18 |
+
--logging_strategy steps \
|
19 |
+
--logging_steps 10 \
|
20 |
+
--eval_steps 50 \
|
21 |
+
--evaluation_strategy steps \
|
22 |
+
--save_steps 500 \
|
23 |
+
--save_strategy steps \
|
24 |
+
--save_total_limit 3 \
|
25 |
+
--gradient_accumulation_steps 1 \
|
26 |
+
--preprocessing_num_workers 4 \
|
27 |
+
--output_dir outputs-sft-bloom-v1 \
|
28 |
+
--overwrite_output_dir \
|
29 |
+
--ddp_timeout 30000 \
|
30 |
+
--logging_first_step True \
|
31 |
+
--target_modules all \
|
32 |
+
--lora_rank 8 \
|
33 |
+
--lora_alpha 16 \
|
34 |
+
--lora_dropout 0.05 \
|
35 |
+
--torch_dtype float16 \
|
36 |
+
--device_map auto \
|
37 |
+
--report_to tensorboard \
|
38 |
+
--ddp_find_unused_parameters False \
|
39 |
+
--gradient_checkpointing True \
|
40 |
+
--cache_dir ./cache
|
MedicalGPT-main/run_training_dpo_pipeline.ipynb
ADDED
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# Training Pipeline\n",
|
7 |
+
"[run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb) | [Open In Colab](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb)"
|
8 |
+
],
|
9 |
+
"metadata": {
|
10 |
+
"collapsed": false
|
11 |
+
}
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"tags": []
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# Stage 1: Continue Pretraining\n",
|
20 |
+
"\n",
|
21 |
+
"第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文本数据上二次预训练GPT模型,以注入领域知识\n",
|
22 |
+
"\n",
|
23 |
+
"| Stage 1: Continue Pretraining | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"#### 说明:\n",
|
31 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
32 |
+
"\n",
|
33 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m`\n",
|
34 |
+
"2. 数据集:PT阶段使用的是中文天龙八部小说部分文本和英文书籍部分文本,位于`data/pretrain`文件夹"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"source": [],
|
40 |
+
"metadata": {
|
41 |
+
"collapsed": false
|
42 |
+
}
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "markdown",
|
46 |
+
"metadata": {},
|
47 |
+
"source": [
|
48 |
+
"## 配置运行环境\n",
|
49 |
+
"\n",
|
50 |
+
"本地执行可注释以下配置环境的命令,colab执行要打开注释,用于配置环境\n",
|
51 |
+
"\n",
|
52 |
+
"colab建议使用T4 GPU训练,设置方式:`代码执行程序 -> 更改运行时类型 -> 运行时类型:Python3,硬件加速器:GPU,GPU类型:T4 -> 保存`\n",
|
53 |
+
"\n",
|
54 |
+
"步骤:\n",
|
55 |
+
"1. 下载最新代码到本地\n",
|
56 |
+
"2. 安装依赖包\n",
|
57 |
+
"\n",
|
58 |
+
"依赖包如下,保证最新版本:\n",
|
59 |
+
"\n",
|
60 |
+
"```\n",
|
61 |
+
"loguru\n",
|
62 |
+
"transformers\n",
|
63 |
+
"sentencepiece\n",
|
64 |
+
"datasets\n",
|
65 |
+
"tensorboard\n",
|
66 |
+
"tqdm\n",
|
67 |
+
"peft\n",
|
68 |
+
"trl\n",
|
69 |
+
"```"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": null,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"!git clone --depth 1 https://github.com/shibing624/MedicalGPT.git\n",
|
79 |
+
"%cd MedicalGPT\n",
|
80 |
+
"%ls\n",
|
81 |
+
"!pip install -r requirements.txt"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "markdown",
|
86 |
+
"metadata": {},
|
87 |
+
"source": [
|
88 |
+
"## Stage1 咱们开始吧\n",
|
89 |
+
"\n",
|
90 |
+
"训练步骤如下:\n",
|
91 |
+
"\n",
|
92 |
+
"1. 确认训练集\n",
|
93 |
+
"2. 执行训练脚本\n",
|
94 |
+
"\n",
|
95 |
+
"训练脚本的执行逻辑如下:\n",
|
96 |
+
"1. 导入依赖包\n",
|
97 |
+
"2. 设置参数\n",
|
98 |
+
"3. 定义各函数并加载训练集\n",
|
99 |
+
"4. 加载模型和tokenizer\n",
|
100 |
+
"5. 开始训练并评估\n",
|
101 |
+
"6. 查看训练结果\n",
|
102 |
+
"\n",
|
103 |
+
"**以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的**"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"metadata": {},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"%ls ./data/pretrain/"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"outputs": [],
|
119 |
+
"source": [
|
120 |
+
"!python pretraining.py \\\n",
|
121 |
+
" --model_type bloom \\\n",
|
122 |
+
" --model_name_or_path bigscience/bloomz-560m \\\n",
|
123 |
+
" --train_file_dir ./data/pretrain \\\n",
|
124 |
+
" --validation_file_dir ./data/pretrain \\\n",
|
125 |
+
" --per_device_train_batch_size 3 \\\n",
|
126 |
+
" --per_device_eval_batch_size 3 \\\n",
|
127 |
+
" --do_train \\\n",
|
128 |
+
" --do_eval \\\n",
|
129 |
+
" --use_peft True \\\n",
|
130 |
+
" --seed 42 \\\n",
|
131 |
+
" --fp16 \\\n",
|
132 |
+
" --max_train_samples 10000 \\\n",
|
133 |
+
" --max_eval_samples 10 \\\n",
|
134 |
+
" --num_train_epochs 1 \\\n",
|
135 |
+
" --learning_rate 2e-4 \\\n",
|
136 |
+
" --warmup_ratio 0.05 \\\n",
|
137 |
+
" --weight_decay 0.01 \\\n",
|
138 |
+
" --logging_strategy steps \\\n",
|
139 |
+
" --logging_steps 10 \\\n",
|
140 |
+
" --eval_steps 50 \\\n",
|
141 |
+
" --evaluation_strategy steps \\\n",
|
142 |
+
" --save_steps 500 \\\n",
|
143 |
+
" --save_strategy steps \\\n",
|
144 |
+
" --save_total_limit 3 \\\n",
|
145 |
+
" --gradient_accumulation_steps 1 \\\n",
|
146 |
+
" --preprocessing_num_workers 1 \\\n",
|
147 |
+
" --block_size 1024 \\\n",
|
148 |
+
" --output_dir outputs-pt-v1 \\\n",
|
149 |
+
" --overwrite_output_dir \\\n",
|
150 |
+
" --ddp_timeout 30000 \\\n",
|
151 |
+
" --logging_first_step True \\\n",
|
152 |
+
" --target_modules all \\\n",
|
153 |
+
" --lora_rank 8 \\\n",
|
154 |
+
" --lora_alpha 16 \\\n",
|
155 |
+
" --lora_dropout 0.05 \\\n",
|
156 |
+
" --torch_dtype float16 \\\n",
|
157 |
+
" --device_map auto \\\n",
|
158 |
+
" --report_to tensorboard \\\n",
|
159 |
+
" --ddp_find_unused_parameters False \\\n",
|
160 |
+
" --gradient_checkpointing True"
|
161 |
+
],
|
162 |
+
"metadata": {
|
163 |
+
"collapsed": false
|
164 |
+
}
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": null,
|
169 |
+
"metadata": {},
|
170 |
+
"outputs": [],
|
171 |
+
"source": [
|
172 |
+
"%ls -lh outputs-pt-v1"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "markdown",
|
177 |
+
"metadata": {},
|
178 |
+
"source": [
|
179 |
+
"模型训练结果:\n",
|
180 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
181 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "markdown",
|
186 |
+
"source": [
|
187 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
188 |
+
],
|
189 |
+
"metadata": {
|
190 |
+
"collapsed": false
|
191 |
+
}
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"cell_type": "code",
|
195 |
+
"execution_count": null,
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
199 |
+
" --base_model_name_or_path bigscience/bloomz-560m --peft_model_path outputs-pt-v1 --output_dir merged-pt/"
|
200 |
+
],
|
201 |
+
"metadata": {
|
202 |
+
"collapsed": false
|
203 |
+
}
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"outputs": [],
|
209 |
+
"source": [
|
210 |
+
"%ls -lh merged-pt/"
|
211 |
+
],
|
212 |
+
"metadata": {
|
213 |
+
"collapsed": false
|
214 |
+
}
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"cell_type": "code",
|
218 |
+
"execution_count": null,
|
219 |
+
"outputs": [],
|
220 |
+
"source": [
|
221 |
+
"%cat merged-pt/config.json"
|
222 |
+
],
|
223 |
+
"metadata": {
|
224 |
+
"collapsed": false
|
225 |
+
}
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "markdown",
|
229 |
+
"metadata": {},
|
230 |
+
"source": [
|
231 |
+
"Stage1 增量预训练完成。"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"execution_count": null,
|
237 |
+
"metadata": {
|
238 |
+
"ExecuteTime": {
|
239 |
+
"start_time": "2023-06-15T13:56:17.032821Z",
|
240 |
+
"end_time": "2023-06-15T13:56:17.081153Z"
|
241 |
+
}
|
242 |
+
},
|
243 |
+
"outputs": [],
|
244 |
+
"source": []
|
245 |
+
},
|
246 |
+
{
|
247 |
+
"cell_type": "markdown",
|
248 |
+
"source": [
|
249 |
+
"# Stage 2: Supervised FineTuning\n",
|
250 |
+
"\n",
|
251 |
+
"第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图\n",
|
252 |
+
"\n",
|
253 |
+
"| Stage 2: Supervised Fine-tuning | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |"
|
254 |
+
],
|
255 |
+
"metadata": {
|
256 |
+
"collapsed": false
|
257 |
+
}
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "markdown",
|
261 |
+
"source": [
|
262 |
+
"#### 说明:\n",
|
263 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
264 |
+
"\n",
|
265 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage1得到的预训练模型\n",
|
266 |
+
"2. 数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
|
267 |
+
],
|
268 |
+
"metadata": {
|
269 |
+
"collapsed": false
|
270 |
+
}
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "markdown",
|
274 |
+
"source": [
|
275 |
+
"## Stage2 咱们开始吧\n",
|
276 |
+
"\n",
|
277 |
+
"训练步骤如下:\n",
|
278 |
+
"\n",
|
279 |
+
"1. 确认训练集\n",
|
280 |
+
"2. 执行训练脚本\n",
|
281 |
+
"\n",
|
282 |
+
"训练脚本的执行逻辑如下:\n",
|
283 |
+
"1. 导入依赖包\n",
|
284 |
+
"2. 设置参数\n",
|
285 |
+
"3. 定义各函数并加载训练集\n",
|
286 |
+
"4. 加载模型和tokenizer\n",
|
287 |
+
"5. 开始训练并评估\n",
|
288 |
+
"6. 查看训练结果"
|
289 |
+
],
|
290 |
+
"metadata": {
|
291 |
+
"collapsed": false
|
292 |
+
}
|
293 |
+
},
|
294 |
+
{
|
295 |
+
"cell_type": "code",
|
296 |
+
"execution_count": null,
|
297 |
+
"outputs": [],
|
298 |
+
"source": [
|
299 |
+
"%ls ./data/finetune"
|
300 |
+
],
|
301 |
+
"metadata": {
|
302 |
+
"collapsed": false,
|
303 |
+
"ExecuteTime": {
|
304 |
+
"start_time": "2023-06-15T13:58:38.778132Z",
|
305 |
+
"end_time": "2023-06-15T13:58:38.966506Z"
|
306 |
+
}
|
307 |
+
}
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": null,
|
312 |
+
"outputs": [],
|
313 |
+
"source": [
|
314 |
+
"!python supervised_finetuning.py \\\n",
|
315 |
+
" --model_type bloom \\\n",
|
316 |
+
" --model_name_or_path merged-pt \\\n",
|
317 |
+
" --train_file_dir ./data/finetune \\\n",
|
318 |
+
" --validation_file_dir ./data/finetune \\\n",
|
319 |
+
" --per_device_train_batch_size 4 \\\n",
|
320 |
+
" --per_device_eval_batch_size 4 \\\n",
|
321 |
+
" --do_train \\\n",
|
322 |
+
" --do_eval \\\n",
|
323 |
+
" --use_peft True \\\n",
|
324 |
+
" --fp16 \\\n",
|
325 |
+
" --max_train_samples 1000 \\\n",
|
326 |
+
" --max_eval_samples 10 \\\n",
|
327 |
+
" --num_train_epochs 1 \\\n",
|
328 |
+
" --learning_rate 2e-5 \\\n",
|
329 |
+
" --warmup_ratio 0.05 \\\n",
|
330 |
+
" --weight_decay 0.05 \\\n",
|
331 |
+
" --logging_strategy steps \\\n",
|
332 |
+
" --logging_steps 10 \\\n",
|
333 |
+
" --eval_steps 50 \\\n",
|
334 |
+
" --evaluation_strategy steps \\\n",
|
335 |
+
" --save_steps 500 \\\n",
|
336 |
+
" --save_strategy steps \\\n",
|
337 |
+
" --save_total_limit 3 \\\n",
|
338 |
+
" --gradient_accumulation_steps 1 \\\n",
|
339 |
+
" --preprocessing_num_workers 1 \\\n",
|
340 |
+
" --output_dir outputs-sft-v1 \\\n",
|
341 |
+
" --overwrite_output_dir \\\n",
|
342 |
+
" --ddp_timeout 30000 \\\n",
|
343 |
+
" --logging_first_step True \\\n",
|
344 |
+
" --target_modules all \\\n",
|
345 |
+
" --lora_rank 8 \\\n",
|
346 |
+
" --lora_alpha 16 \\\n",
|
347 |
+
" --lora_dropout 0.05 \\\n",
|
348 |
+
" --torch_dtype float16 \\\n",
|
349 |
+
" --device_map auto \\\n",
|
350 |
+
" --report_to tensorboard \\\n",
|
351 |
+
" --ddp_find_unused_parameters False \\\n",
|
352 |
+
" --gradient_checkpointing True"
|
353 |
+
],
|
354 |
+
"metadata": {
|
355 |
+
"collapsed": false
|
356 |
+
}
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"outputs": [],
|
362 |
+
"source": [
|
363 |
+
"%ls -lh outputs-sft-v1"
|
364 |
+
],
|
365 |
+
"metadata": {
|
366 |
+
"collapsed": false
|
367 |
+
}
|
368 |
+
},
|
369 |
+
{
|
370 |
+
"cell_type": "markdown",
|
371 |
+
"source": [
|
372 |
+
"模型训练结果:\n",
|
373 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
374 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
375 |
+
],
|
376 |
+
"metadata": {
|
377 |
+
"collapsed": false
|
378 |
+
}
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "markdown",
|
382 |
+
"source": [
|
383 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
384 |
+
],
|
385 |
+
"metadata": {
|
386 |
+
"collapsed": false
|
387 |
+
}
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"execution_count": null,
|
392 |
+
"outputs": [],
|
393 |
+
"source": [
|
394 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
395 |
+
" --base_model_name_or_path merged-pt --peft_model_path outputs-sft-v1 --output_dir merged-sft/"
|
396 |
+
],
|
397 |
+
"metadata": {
|
398 |
+
"collapsed": false
|
399 |
+
}
|
400 |
+
},
|
401 |
+
{
|
402 |
+
"cell_type": "code",
|
403 |
+
"execution_count": null,
|
404 |
+
"outputs": [],
|
405 |
+
"source": [
|
406 |
+
"%ls -lh merged-sft/"
|
407 |
+
],
|
408 |
+
"metadata": {
|
409 |
+
"collapsed": false
|
410 |
+
}
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"cell_type": "code",
|
414 |
+
"execution_count": null,
|
415 |
+
"outputs": [],
|
416 |
+
"source": [
|
417 |
+
"%cat merged-sft/config.json"
|
418 |
+
],
|
419 |
+
"metadata": {
|
420 |
+
"collapsed": false
|
421 |
+
}
|
422 |
+
},
|
423 |
+
{
|
424 |
+
"cell_type": "markdown",
|
425 |
+
"source": [
|
426 |
+
"Stage2 SFT训练完成。"
|
427 |
+
],
|
428 |
+
"metadata": {
|
429 |
+
"collapsed": false
|
430 |
+
}
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "code",
|
434 |
+
"execution_count": null,
|
435 |
+
"outputs": [],
|
436 |
+
"source": [],
|
437 |
+
"metadata": {
|
438 |
+
"collapsed": false,
|
439 |
+
"ExecuteTime": {
|
440 |
+
"start_time": "2023-06-15T14:07:40.731186Z",
|
441 |
+
"end_time": "2023-06-15T14:07:40.752635Z"
|
442 |
+
}
|
443 |
+
}
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"cell_type": "markdown",
|
447 |
+
"source": [
|
448 |
+
"# Stage 3: DPO(Direct Preference Optimization)\n",
|
449 |
+
"\n",
|
450 |
+
"第三阶段:DPO(Direct Preference Optimization)直接偏好优化,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好\n",
|
451 |
+
"\n",
|
452 |
+
"| Stage 3: Direct Preference Optimization | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |"
|
453 |
+
],
|
454 |
+
"metadata": {
|
455 |
+
"collapsed": false
|
456 |
+
}
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "markdown",
|
460 |
+
"source": [
|
461 |
+
"#### 说明:\n",
|
462 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
463 |
+
"\n",
|
464 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
|
465 |
+
"2. 数据集:DPO阶段使用的是医疗reward数据,抽样了500条,位于`data/reward`文件夹"
|
466 |
+
],
|
467 |
+
"metadata": {
|
468 |
+
"collapsed": false
|
469 |
+
}
|
470 |
+
},
|
471 |
+
{
|
472 |
+
"cell_type": "markdown",
|
473 |
+
"source": [
|
474 |
+
"## Stage3 咱们开始吧\n",
|
475 |
+
"\n",
|
476 |
+
"训练步骤如下:\n",
|
477 |
+
"\n",
|
478 |
+
"1. 确认训练集\n",
|
479 |
+
"2. 执行训练脚本\n",
|
480 |
+
"\n",
|
481 |
+
"训练脚本的执行逻辑如下:\n",
|
482 |
+
"1. 导入依赖包\n",
|
483 |
+
"2. 设置参数\n",
|
484 |
+
"3. 定义各函数并加载训练集\n",
|
485 |
+
"4. 加载模型和tokenizer\n",
|
486 |
+
"5. 开始训练并评估\n",
|
487 |
+
"6. 查看训练结果"
|
488 |
+
],
|
489 |
+
"metadata": {
|
490 |
+
"collapsed": false
|
491 |
+
}
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"cell_type": "code",
|
495 |
+
"execution_count": null,
|
496 |
+
"outputs": [],
|
497 |
+
"source": [
|
498 |
+
"%ls ./data/reward/"
|
499 |
+
],
|
500 |
+
"metadata": {
|
501 |
+
"collapsed": false
|
502 |
+
}
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "code",
|
506 |
+
"execution_count": null,
|
507 |
+
"outputs": [],
|
508 |
+
"source": [
|
509 |
+
"!python dpo_training.py \\\n",
|
510 |
+
" --model_type bloom \\\n",
|
511 |
+
" --model_name_or_path merged-sft \\\n",
|
512 |
+
" --train_file_dir ./data/reward \\\n",
|
513 |
+
" --validation_file_dir ./data/reward \\\n",
|
514 |
+
" --per_device_train_batch_size 3 \\\n",
|
515 |
+
" --per_device_eval_batch_size 1 \\\n",
|
516 |
+
" --do_train \\\n",
|
517 |
+
" --do_eval \\\n",
|
518 |
+
" --use_peft True \\\n",
|
519 |
+
" --max_train_samples 1000 \\\n",
|
520 |
+
" --max_eval_samples 10 \\\n",
|
521 |
+
" --max_steps 100 \\\n",
|
522 |
+
" --eval_steps 10 \\\n",
|
523 |
+
" --save_steps 50 \\\n",
|
524 |
+
" --max_source_length 128 \\\n",
|
525 |
+
" --max_target_length 128 \\\n",
|
526 |
+
" --output_dir outputs-dpo-v1 \\\n",
|
527 |
+
" --target_modules all \\\n",
|
528 |
+
" --lora_rank 8 \\\n",
|
529 |
+
" --lora_alpha 16 \\\n",
|
530 |
+
" --lora_dropout 0.05 \\\n",
|
531 |
+
" --torch_dtype float16 \\\n",
|
532 |
+
" --fp16 True \\\n",
|
533 |
+
" --device_map auto \\\n",
|
534 |
+
" --report_to tensorboard \\\n",
|
535 |
+
" --remove_unused_columns False \\\n",
|
536 |
+
" --gradient_checkpointing True \\\n",
|
537 |
+
" --cache_dir ./cache"
|
538 |
+
],
|
539 |
+
"metadata": {
|
540 |
+
"collapsed": false
|
541 |
+
}
|
542 |
+
},
|
543 |
+
{
|
544 |
+
"cell_type": "code",
|
545 |
+
"execution_count": null,
|
546 |
+
"outputs": [],
|
547 |
+
"source": [
|
548 |
+
"%ls -lh outputs-dpo-v1"
|
549 |
+
],
|
550 |
+
"metadata": {
|
551 |
+
"collapsed": false
|
552 |
+
}
|
553 |
+
},
|
554 |
+
{
|
555 |
+
"cell_type": "markdown",
|
556 |
+
"source": [
|
557 |
+
"模型训练结果:\n",
|
558 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
559 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
560 |
+
],
|
561 |
+
"metadata": {
|
562 |
+
"collapsed": false
|
563 |
+
}
|
564 |
+
},
|
565 |
+
{
|
566 |
+
"cell_type": "markdown",
|
567 |
+
"source": [
|
568 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
569 |
+
],
|
570 |
+
"metadata": {
|
571 |
+
"collapsed": false
|
572 |
+
}
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"cell_type": "code",
|
576 |
+
"execution_count": null,
|
577 |
+
"outputs": [],
|
578 |
+
"source": [
|
579 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
580 |
+
" --base_model_name_or_path merged-sft --peft_model_path outputs-dpo-v1 --output_dir merged-dpo/"
|
581 |
+
],
|
582 |
+
"metadata": {
|
583 |
+
"collapsed": false
|
584 |
+
}
|
585 |
+
},
|
586 |
+
{
|
587 |
+
"cell_type": "code",
|
588 |
+
"execution_count": null,
|
589 |
+
"outputs": [],
|
590 |
+
"source": [
|
591 |
+
"%ls -lh merged-dpo/"
|
592 |
+
],
|
593 |
+
"metadata": {
|
594 |
+
"collapsed": false
|
595 |
+
}
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"cell_type": "code",
|
599 |
+
"execution_count": null,
|
600 |
+
"outputs": [],
|
601 |
+
"source": [
|
602 |
+
"%cat merged-dpo/config.json"
|
603 |
+
],
|
604 |
+
"metadata": {
|
605 |
+
"collapsed": false
|
606 |
+
}
|
607 |
+
},
|
608 |
+
{
|
609 |
+
"cell_type": "markdown",
|
610 |
+
"source": [
|
611 |
+
"Stage3 偏好建模第一次训练完成。"
|
612 |
+
],
|
613 |
+
"metadata": {
|
614 |
+
"collapsed": false
|
615 |
+
}
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "markdown",
|
619 |
+
"source": [
|
620 |
+
"**至此一个完整的训练流程演示完成。**"
|
621 |
+
],
|
622 |
+
"metadata": {
|
623 |
+
"collapsed": false
|
624 |
+
}
|
625 |
+
},
|
626 |
+
{
|
627 |
+
"cell_type": "code",
|
628 |
+
"execution_count": null,
|
629 |
+
"outputs": [],
|
630 |
+
"source": [],
|
631 |
+
"metadata": {
|
632 |
+
"collapsed": false,
|
633 |
+
"ExecuteTime": {
|
634 |
+
"start_time": "2023-06-26T12:34:29.620609Z",
|
635 |
+
"end_time": "2023-06-26T12:34:29.658428Z"
|
636 |
+
}
|
637 |
+
}
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"cell_type": "markdown",
|
641 |
+
"source": [
|
642 |
+
"# Test"
|
643 |
+
],
|
644 |
+
"metadata": {
|
645 |
+
"collapsed": false
|
646 |
+
}
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"cell_type": "code",
|
650 |
+
"execution_count": null,
|
651 |
+
"outputs": [],
|
652 |
+
"source": [
|
653 |
+
"!python inference.py --model_type bloom --base_model merged-dpo --interactive"
|
654 |
+
],
|
655 |
+
"metadata": {
|
656 |
+
"collapsed": false,
|
657 |
+
"ExecuteTime": {
|
658 |
+
"start_time": "2023-06-26T12:34:47.802087Z",
|
659 |
+
"end_time": "2023-06-26T12:35:00.864463Z"
|
660 |
+
}
|
661 |
+
}
|
662 |
+
},
|
663 |
+
{
|
664 |
+
"cell_type": "markdown",
|
665 |
+
"source": [
|
666 |
+
"Input:介绍下南京\n",
|
667 |
+
"Response: 南京市位于江苏省西南部,是全国首批历史文化名城、国家中心城市和自由贸易试验区。\n",
|
668 |
+
"\n",
|
669 |
+
"完。\n"
|
670 |
+
],
|
671 |
+
"metadata": {
|
672 |
+
"collapsed": false
|
673 |
+
}
|
674 |
+
},
|
675 |
+
{
|
676 |
+
"cell_type": "code",
|
677 |
+
"execution_count": null,
|
678 |
+
"outputs": [],
|
679 |
+
"source": [],
|
680 |
+
"metadata": {
|
681 |
+
"collapsed": false
|
682 |
+
}
|
683 |
+
}
|
684 |
+
],
|
685 |
+
"metadata": {
|
686 |
+
"kernelspec": {
|
687 |
+
"name": "python3",
|
688 |
+
"language": "python",
|
689 |
+
"display_name": "Python 3"
|
690 |
+
},
|
691 |
+
"language_info": {
|
692 |
+
"codemirror_mode": {
|
693 |
+
"name": "ipython",
|
694 |
+
"version": 3
|
695 |
+
},
|
696 |
+
"file_extension": ".py",
|
697 |
+
"mimetype": "text/x-python",
|
698 |
+
"name": "python",
|
699 |
+
"nbconvert_exporter": "python",
|
700 |
+
"pygments_lexer": "ipython3",
|
701 |
+
"version": "3.8.13"
|
702 |
+
},
|
703 |
+
"vscode": {
|
704 |
+
"interpreter": {
|
705 |
+
"hash": "f34eed0bebedfc4b6ee51ced43d2c030fe3b92f13c149d072205ca200a67b1ec"
|
706 |
+
}
|
707 |
+
}
|
708 |
+
},
|
709 |
+
"nbformat": 4,
|
710 |
+
"nbformat_minor": 4
|
711 |
+
}
|
MedicalGPT-main/run_training_pipeline.ipynb
ADDED
@@ -0,0 +1,917 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# Training Pipeline\n",
|
7 |
+
"[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) | [Open In Colab](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb)"
|
8 |
+
],
|
9 |
+
"metadata": {
|
10 |
+
"collapsed": false
|
11 |
+
}
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"tags": []
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"# Stage 1: Continue Pretraining\n",
|
20 |
+
"\n",
|
21 |
+
"第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文本数据上二次预训练GPT模型,以注入领域知识\n",
|
22 |
+
"\n",
|
23 |
+
"| Stage 1: Continue Pretraining | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"#### 说明:\n",
|
31 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
32 |
+
"\n",
|
33 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m`\n",
|
34 |
+
"2. 数据集:PT阶段使用的是中文天龙八部小说部分文本和英文书籍部分文本,位于`data/pretrain`文件夹"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"metadata": {},
|
40 |
+
"source": [
|
41 |
+
"## 配置运行环境\n",
|
42 |
+
"\n",
|
43 |
+
"本地执行可注释以下配置环境的命令,colab执行要打开注释,用于配置环境\n",
|
44 |
+
"\n",
|
45 |
+
"colab建议使用T4 GPU训练,设置方式:`代码执行程序 -> 更改运行时类型 -> 运行时类型:Python3,硬件加速器:GPU,GPU类型:T4 -> 保存`\n",
|
46 |
+
"\n",
|
47 |
+
"步骤:\n",
|
48 |
+
"1. 下载最新代码到本地\n",
|
49 |
+
"2. 安装依赖包\n",
|
50 |
+
"\n",
|
51 |
+
"依赖包如下,保证最新版本:\n",
|
52 |
+
"\n",
|
53 |
+
"```\n",
|
54 |
+
"loguru\n",
|
55 |
+
"transformers\n",
|
56 |
+
"sentencepiece\n",
|
57 |
+
"datasets\n",
|
58 |
+
"tensorboard\n",
|
59 |
+
"tqdm\n",
|
60 |
+
"peft\n",
|
61 |
+
"trl\n",
|
62 |
+
"```"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [],
|
70 |
+
"source": [
|
71 |
+
"!git clone --depth 1 https://github.com/shibing624/MedicalGPT.git\n",
|
72 |
+
"%cd MedicalGPT\n",
|
73 |
+
"%ls\n",
|
74 |
+
"!pip install -r requirements.txt"
|
75 |
+
]
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"cell_type": "markdown",
|
79 |
+
"metadata": {},
|
80 |
+
"source": [
|
81 |
+
"## Stage1 咱们开始吧\n",
|
82 |
+
"\n",
|
83 |
+
"训练步骤如下:\n",
|
84 |
+
"\n",
|
85 |
+
"1. 确认训练集\n",
|
86 |
+
"2. 执行训练脚本\n",
|
87 |
+
"\n",
|
88 |
+
"训练脚本的执行逻辑如下:\n",
|
89 |
+
"1. 导入依赖包\n",
|
90 |
+
"2. 设置参数\n",
|
91 |
+
"3. 定义各函数并加载训练集\n",
|
92 |
+
"4. 加载模型和tokenizer\n",
|
93 |
+
"5. 开始训练并评估\n",
|
94 |
+
"6. 查看训练结果\n",
|
95 |
+
"\n",
|
96 |
+
"**以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的**"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"%ls ./data/pretrain/"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": null,
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"!python pretraining.py \\\n",
|
114 |
+
" --model_type bloom \\\n",
|
115 |
+
" --model_name_or_path bigscience/bloomz-560m \\\n",
|
116 |
+
" --train_file_dir ./data/pretrain \\\n",
|
117 |
+
" --validation_file_dir ./data/pretrain \\\n",
|
118 |
+
" --per_device_train_batch_size 3 \\\n",
|
119 |
+
" --per_device_eval_batch_size 3 \\\n",
|
120 |
+
" --do_train \\\n",
|
121 |
+
" --do_eval \\\n",
|
122 |
+
" --use_peft True \\\n",
|
123 |
+
" --seed 42 \\\n",
|
124 |
+
" --fp16 \\\n",
|
125 |
+
" --max_train_samples 10000 \\\n",
|
126 |
+
" --max_eval_samples 10 \\\n",
|
127 |
+
" --num_train_epochs 1 \\\n",
|
128 |
+
" --learning_rate 2e-4 \\\n",
|
129 |
+
" --warmup_ratio 0.05 \\\n",
|
130 |
+
" --weight_decay 0.01 \\\n",
|
131 |
+
" --logging_strategy steps \\\n",
|
132 |
+
" --logging_steps 10 \\\n",
|
133 |
+
" --eval_steps 50 \\\n",
|
134 |
+
" --evaluation_strategy steps \\\n",
|
135 |
+
" --save_steps 500 \\\n",
|
136 |
+
" --save_strategy steps \\\n",
|
137 |
+
" --save_total_limit 3 \\\n",
|
138 |
+
" --gradient_accumulation_steps 1 \\\n",
|
139 |
+
" --preprocessing_num_workers 1 \\\n",
|
140 |
+
" --block_size 1024 \\\n",
|
141 |
+
" --output_dir outputs-pt-v1 \\\n",
|
142 |
+
" --overwrite_output_dir \\\n",
|
143 |
+
" --ddp_timeout 30000 \\\n",
|
144 |
+
" --logging_first_step True \\\n",
|
145 |
+
" --target_modules all \\\n",
|
146 |
+
" --lora_rank 8 \\\n",
|
147 |
+
" --lora_alpha 16 \\\n",
|
148 |
+
" --lora_dropout 0.05 \\\n",
|
149 |
+
" --torch_dtype float16 \\\n",
|
150 |
+
" --device_map auto \\\n",
|
151 |
+
" --report_to tensorboard \\\n",
|
152 |
+
" --ddp_find_unused_parameters False \\\n",
|
153 |
+
" --gradient_checkpointing True"
|
154 |
+
],
|
155 |
+
"metadata": {
|
156 |
+
"collapsed": false
|
157 |
+
}
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "code",
|
161 |
+
"execution_count": null,
|
162 |
+
"metadata": {},
|
163 |
+
"outputs": [],
|
164 |
+
"source": [
|
165 |
+
"%ls -lh outputs-pt-v1"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"cell_type": "markdown",
|
170 |
+
"metadata": {},
|
171 |
+
"source": [
|
172 |
+
"模型训练结果:\n",
|
173 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
174 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
175 |
+
]
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "markdown",
|
179 |
+
"source": [
|
180 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
181 |
+
],
|
182 |
+
"metadata": {
|
183 |
+
"collapsed": false
|
184 |
+
}
|
185 |
+
},
|
186 |
+
{
|
187 |
+
"cell_type": "code",
|
188 |
+
"execution_count": null,
|
189 |
+
"outputs": [],
|
190 |
+
"source": [
|
191 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
192 |
+
" --base_model_name_or_path bigscience/bloomz-560m --peft_model_path outputs-pt-v1 --output_dir merged-pt/"
|
193 |
+
],
|
194 |
+
"metadata": {
|
195 |
+
"collapsed": false
|
196 |
+
}
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": null,
|
201 |
+
"outputs": [],
|
202 |
+
"source": [
|
203 |
+
"%ls -lh merged-pt/"
|
204 |
+
],
|
205 |
+
"metadata": {
|
206 |
+
"collapsed": false
|
207 |
+
}
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": null,
|
212 |
+
"outputs": [],
|
213 |
+
"source": [
|
214 |
+
"%cat merged-pt/config.json"
|
215 |
+
],
|
216 |
+
"metadata": {
|
217 |
+
"collapsed": false
|
218 |
+
}
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "markdown",
|
222 |
+
"metadata": {},
|
223 |
+
"source": [
|
224 |
+
"Stage1 增量预训练完成。"
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": null,
|
230 |
+
"metadata": {
|
231 |
+
"ExecuteTime": {
|
232 |
+
"start_time": "2023-06-15T13:56:17.032821Z",
|
233 |
+
"end_time": "2023-06-15T13:56:17.081153Z"
|
234 |
+
}
|
235 |
+
},
|
236 |
+
"outputs": [],
|
237 |
+
"source": []
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "markdown",
|
241 |
+
"source": [
|
242 |
+
"# Stage 2: Supervised FineTuning\n",
|
243 |
+
"\n",
|
244 |
+
"第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图\n",
|
245 |
+
"\n",
|
246 |
+
"| Stage 2: Supervised Fine-tuning | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |"
|
247 |
+
],
|
248 |
+
"metadata": {
|
249 |
+
"collapsed": false
|
250 |
+
}
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "markdown",
|
254 |
+
"source": [
|
255 |
+
"#### 说明:\n",
|
256 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
257 |
+
"\n",
|
258 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage1得到的预训练模型\n",
|
259 |
+
"2. 数据集:SFT阶段使用的是使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
|
260 |
+
],
|
261 |
+
"metadata": {
|
262 |
+
"collapsed": false
|
263 |
+
}
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "markdown",
|
267 |
+
"source": [
|
268 |
+
"## Stage2 咱们开始吧\n",
|
269 |
+
"\n",
|
270 |
+
"训练步骤如下:\n",
|
271 |
+
"\n",
|
272 |
+
"1. 确认训练集\n",
|
273 |
+
"2. 执行训练脚本\n",
|
274 |
+
"\n",
|
275 |
+
"训练脚本的执行逻辑如下:\n",
|
276 |
+
"1. 导入依赖包\n",
|
277 |
+
"2. 设置参数\n",
|
278 |
+
"3. 定义各函数并加载训练集\n",
|
279 |
+
"4. 加载模型和tokenizer\n",
|
280 |
+
"5. 开始训练并评估\n",
|
281 |
+
"6. 查看训练结果"
|
282 |
+
],
|
283 |
+
"metadata": {
|
284 |
+
"collapsed": false
|
285 |
+
}
|
286 |
+
},
|
287 |
+
{
|
288 |
+
"cell_type": "code",
|
289 |
+
"execution_count": null,
|
290 |
+
"outputs": [],
|
291 |
+
"source": [
|
292 |
+
"%ls ./data/finetune"
|
293 |
+
],
|
294 |
+
"metadata": {
|
295 |
+
"collapsed": false,
|
296 |
+
"ExecuteTime": {
|
297 |
+
"start_time": "2023-06-15T13:58:38.778132Z",
|
298 |
+
"end_time": "2023-06-15T13:58:38.966506Z"
|
299 |
+
}
|
300 |
+
}
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": null,
|
305 |
+
"outputs": [],
|
306 |
+
"source": [
|
307 |
+
"!python supervised_finetuning.py \\\n",
|
308 |
+
" --model_type bloom \\\n",
|
309 |
+
" --model_name_or_path merged-pt \\\n",
|
310 |
+
" --train_file_dir ./data/finetune \\\n",
|
311 |
+
" --validation_file_dir ./data/finetune \\\n",
|
312 |
+
" --per_device_train_batch_size 4 \\\n",
|
313 |
+
" --per_device_eval_batch_size 4 \\\n",
|
314 |
+
" --do_train \\\n",
|
315 |
+
" --do_eval \\\n",
|
316 |
+
" --use_peft True \\\n",
|
317 |
+
" --fp16 \\\n",
|
318 |
+
" --max_train_samples 1000 \\\n",
|
319 |
+
" --max_eval_samples 10 \\\n",
|
320 |
+
" --num_train_epochs 1 \\\n",
|
321 |
+
" --learning_rate 2e-5 \\\n",
|
322 |
+
" --warmup_ratio 0.05 \\\n",
|
323 |
+
" --weight_decay 0.05 \\\n",
|
324 |
+
" --logging_strategy steps \\\n",
|
325 |
+
" --logging_steps 10 \\\n",
|
326 |
+
" --eval_steps 50 \\\n",
|
327 |
+
" --evaluation_strategy steps \\\n",
|
328 |
+
" --save_steps 500 \\\n",
|
329 |
+
" --save_strategy steps \\\n",
|
330 |
+
" --save_total_limit 3 \\\n",
|
331 |
+
" --gradient_accumulation_steps 1 \\\n",
|
332 |
+
" --preprocessing_num_workers 1 \\\n",
|
333 |
+
" --output_dir outputs-sft-v1 \\\n",
|
334 |
+
" --overwrite_output_dir \\\n",
|
335 |
+
" --ddp_timeout 30000 \\\n",
|
336 |
+
" --logging_first_step True \\\n",
|
337 |
+
" --target_modules all \\\n",
|
338 |
+
" --lora_rank 8 \\\n",
|
339 |
+
" --lora_alpha 16 \\\n",
|
340 |
+
" --lora_dropout 0.05 \\\n",
|
341 |
+
" --torch_dtype float16 \\\n",
|
342 |
+
" --device_map auto \\\n",
|
343 |
+
" --report_to tensorboard \\\n",
|
344 |
+
" --ddp_find_unused_parameters False \\\n",
|
345 |
+
" --gradient_checkpointing True"
|
346 |
+
],
|
347 |
+
"metadata": {
|
348 |
+
"collapsed": false
|
349 |
+
}
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"cell_type": "code",
|
353 |
+
"execution_count": null,
|
354 |
+
"outputs": [],
|
355 |
+
"source": [
|
356 |
+
"%ls -lh outputs-sft-v1"
|
357 |
+
],
|
358 |
+
"metadata": {
|
359 |
+
"collapsed": false
|
360 |
+
}
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "markdown",
|
364 |
+
"source": [
|
365 |
+
"模型训练结果:\n",
|
366 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
367 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
368 |
+
],
|
369 |
+
"metadata": {
|
370 |
+
"collapsed": false
|
371 |
+
}
|
372 |
+
},
|
373 |
+
{
|
374 |
+
"cell_type": "markdown",
|
375 |
+
"source": [
|
376 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
377 |
+
],
|
378 |
+
"metadata": {
|
379 |
+
"collapsed": false
|
380 |
+
}
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "code",
|
384 |
+
"execution_count": null,
|
385 |
+
"outputs": [],
|
386 |
+
"source": [
|
387 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
388 |
+
" --base_model_name_or_path merged-pt --peft_model_path outputs-sft-v1 --output_dir merged-sft/"
|
389 |
+
],
|
390 |
+
"metadata": {
|
391 |
+
"collapsed": false
|
392 |
+
}
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"cell_type": "code",
|
396 |
+
"execution_count": null,
|
397 |
+
"outputs": [],
|
398 |
+
"source": [
|
399 |
+
"%ls -lh merged-sft/"
|
400 |
+
],
|
401 |
+
"metadata": {
|
402 |
+
"collapsed": false
|
403 |
+
}
|
404 |
+
},
|
405 |
+
{
|
406 |
+
"cell_type": "code",
|
407 |
+
"execution_count": null,
|
408 |
+
"outputs": [],
|
409 |
+
"source": [
|
410 |
+
"%cat merged-sft/config.json"
|
411 |
+
],
|
412 |
+
"metadata": {
|
413 |
+
"collapsed": false
|
414 |
+
}
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"cell_type": "markdown",
|
418 |
+
"source": [
|
419 |
+
"Stage2 SFT训练完成。"
|
420 |
+
],
|
421 |
+
"metadata": {
|
422 |
+
"collapsed": false
|
423 |
+
}
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "code",
|
427 |
+
"execution_count": null,
|
428 |
+
"outputs": [],
|
429 |
+
"source": [],
|
430 |
+
"metadata": {
|
431 |
+
"collapsed": false,
|
432 |
+
"ExecuteTime": {
|
433 |
+
"start_time": "2023-06-15T14:07:40.731186Z",
|
434 |
+
"end_time": "2023-06-15T14:07:40.752635Z"
|
435 |
+
}
|
436 |
+
}
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"cell_type": "markdown",
|
440 |
+
"source": [
|
441 |
+
"# Stage 3: Reward Modeling\n",
|
442 |
+
"\n",
|
443 |
+
"第三阶段:RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来对齐人类偏好,主要是\"HHH\"原则,具体是\"helpful, honest, harmless\"\n",
|
444 |
+
"\n",
|
445 |
+
"| Stage 3: Reward Modeling | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |"
|
446 |
+
],
|
447 |
+
"metadata": {
|
448 |
+
"collapsed": false
|
449 |
+
}
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"cell_type": "markdown",
|
453 |
+
"source": [
|
454 |
+
"#### 说明:\n",
|
455 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
456 |
+
"\n",
|
457 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
|
458 |
+
"2. 数据集:RM阶段使用的是医疗reward数据,抽样了500条,位于`data/reward`文件夹"
|
459 |
+
],
|
460 |
+
"metadata": {
|
461 |
+
"collapsed": false
|
462 |
+
}
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "markdown",
|
466 |
+
"source": [
|
467 |
+
"## Stage3 咱们开始吧\n",
|
468 |
+
"\n",
|
469 |
+
"训练步骤如下:\n",
|
470 |
+
"\n",
|
471 |
+
"1. 确认训练集\n",
|
472 |
+
"2. 执行训练脚本\n",
|
473 |
+
"\n",
|
474 |
+
"训练脚本的执行逻辑如下:\n",
|
475 |
+
"1. 导入依赖包\n",
|
476 |
+
"2. 设置参数\n",
|
477 |
+
"3. 定义各函数并加载训练集\n",
|
478 |
+
"4. 加载模型和tokenizer\n",
|
479 |
+
"5. 开始训练并评估\n",
|
480 |
+
"6. 查看训练结果"
|
481 |
+
],
|
482 |
+
"metadata": {
|
483 |
+
"collapsed": false
|
484 |
+
}
|
485 |
+
},
|
486 |
+
{
|
487 |
+
"cell_type": "code",
|
488 |
+
"execution_count": null,
|
489 |
+
"outputs": [],
|
490 |
+
"source": [
|
491 |
+
"%ls ./data/reward/"
|
492 |
+
],
|
493 |
+
"metadata": {
|
494 |
+
"collapsed": false
|
495 |
+
}
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": null,
|
500 |
+
"outputs": [],
|
501 |
+
"source": [
|
502 |
+
"!python reward_modeling.py \\\n",
|
503 |
+
" --model_type bloom \\\n",
|
504 |
+
" --model_name_or_path merged-sft \\\n",
|
505 |
+
" --train_file_dir ./data/reward \\\n",
|
506 |
+
" --validation_file_dir ./data/reward \\\n",
|
507 |
+
" --per_device_train_batch_size 3 \\\n",
|
508 |
+
" --per_device_eval_batch_size 1 \\\n",
|
509 |
+
" --do_train \\\n",
|
510 |
+
" --use_peft True \\\n",
|
511 |
+
" --seed 42 \\\n",
|
512 |
+
" --max_train_samples 1000 \\\n",
|
513 |
+
" --max_eval_samples 10 \\\n",
|
514 |
+
" --num_train_epochs 1 \\\n",
|
515 |
+
" --learning_rate 2e-5 \\\n",
|
516 |
+
" --warmup_ratio 0.05 \\\n",
|
517 |
+
" --weight_decay 0.001 \\\n",
|
518 |
+
" --logging_strategy steps \\\n",
|
519 |
+
" --logging_steps 10 \\\n",
|
520 |
+
" --eval_steps 50 \\\n",
|
521 |
+
" --evaluation_strategy steps \\\n",
|
522 |
+
" --save_steps 500 \\\n",
|
523 |
+
" --save_strategy steps \\\n",
|
524 |
+
" --save_total_limit 3 \\\n",
|
525 |
+
" --max_source_length 256 \\\n",
|
526 |
+
" --max_target_length 256 \\\n",
|
527 |
+
" --output_dir outputs-rm-v1 \\\n",
|
528 |
+
" --overwrite_output_dir \\\n",
|
529 |
+
" --ddp_timeout 30000 \\\n",
|
530 |
+
" --logging_first_step True \\\n",
|
531 |
+
" --target_modules all \\\n",
|
532 |
+
" --lora_rank 8 \\\n",
|
533 |
+
" --lora_alpha 16 \\\n",
|
534 |
+
" --lora_dropout 0.05 \\\n",
|
535 |
+
" --torch_dtype float32 \\\n",
|
536 |
+
" --device_map auto \\\n",
|
537 |
+
" --report_to tensorboard \\\n",
|
538 |
+
" --ddp_find_unused_parameters False \\\n",
|
539 |
+
" --remove_unused_columns False \\\n",
|
540 |
+
" --gradient_checkpointing True"
|
541 |
+
],
|
542 |
+
"metadata": {
|
543 |
+
"collapsed": false
|
544 |
+
}
|
545 |
+
},
|
546 |
+
{
|
547 |
+
"cell_type": "code",
|
548 |
+
"execution_count": null,
|
549 |
+
"outputs": [],
|
550 |
+
"source": [
|
551 |
+
"%ls -lh outputs-rm-v1"
|
552 |
+
],
|
553 |
+
"metadata": {
|
554 |
+
"collapsed": false
|
555 |
+
}
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"cell_type": "markdown",
|
559 |
+
"source": [
|
560 |
+
"模型训练结果:\n",
|
561 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
562 |
+
"- 日志保存在`output_dir/runs`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/runs --host 0.0.0.0 --port 8009`"
|
563 |
+
],
|
564 |
+
"metadata": {
|
565 |
+
"collapsed": false
|
566 |
+
}
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"cell_type": "markdown",
|
570 |
+
"source": [
|
571 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
572 |
+
],
|
573 |
+
"metadata": {
|
574 |
+
"collapsed": false
|
575 |
+
}
|
576 |
+
},
|
577 |
+
{
|
578 |
+
"cell_type": "code",
|
579 |
+
"execution_count": null,
|
580 |
+
"outputs": [],
|
581 |
+
"source": [
|
582 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
583 |
+
" --base_model_name_or_path merged-sft --peft_model_path outputs-rm-v1 --output_dir merged-rm/"
|
584 |
+
],
|
585 |
+
"metadata": {
|
586 |
+
"collapsed": false
|
587 |
+
}
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"cell_type": "code",
|
591 |
+
"execution_count": null,
|
592 |
+
"outputs": [],
|
593 |
+
"source": [
|
594 |
+
"%ls -lh merged-rm/"
|
595 |
+
],
|
596 |
+
"metadata": {
|
597 |
+
"collapsed": false
|
598 |
+
}
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"cell_type": "code",
|
602 |
+
"execution_count": null,
|
603 |
+
"outputs": [],
|
604 |
+
"source": [
|
605 |
+
"%cat merged-rm/config.json"
|
606 |
+
],
|
607 |
+
"metadata": {
|
608 |
+
"collapsed": false
|
609 |
+
}
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"cell_type": "markdown",
|
613 |
+
"source": [
|
614 |
+
"Stage3 奖励建模第一次训练完成。"
|
615 |
+
],
|
616 |
+
"metadata": {
|
617 |
+
"collapsed": false
|
618 |
+
}
|
619 |
+
},
|
620 |
+
{
|
621 |
+
"cell_type": "code",
|
622 |
+
"execution_count": null,
|
623 |
+
"outputs": [],
|
624 |
+
"source": [],
|
625 |
+
"metadata": {
|
626 |
+
"collapsed": false,
|
627 |
+
"ExecuteTime": {
|
628 |
+
"start_time": "2023-06-15T14:12:09.464881Z",
|
629 |
+
"end_time": "2023-06-15T14:12:09.472414Z"
|
630 |
+
}
|
631 |
+
}
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"cell_type": "markdown",
|
635 |
+
"source": [
|
636 |
+
"# Stage 4: Reinforcement Learning Training\n",
|
637 |
+
"\n",
|
638 |
+
"第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF),用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本\n",
|
639 |
+
"\n",
|
640 |
+
"| Stage 4: Reinforcement Learning | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |\n"
|
641 |
+
],
|
642 |
+
"metadata": {
|
643 |
+
"collapsed": false
|
644 |
+
}
|
645 |
+
},
|
646 |
+
{
|
647 |
+
"cell_type": "markdown",
|
648 |
+
"source": [
|
649 |
+
"#### 说明:\n",
|
650 |
+
"以下 notebook/colab 代码为了快速验证训练代码可用,我们使用了小size的生成模型、奖励模型和小样本数据集,实际使用时,需要使用更大的模型和数据集,以获得更好的效果。\n",
|
651 |
+
"\n",
|
652 |
+
"1. 生成模型:使用的是Bloom的`bigscience/bloomz-560m` 或者 Stage2得到的SFT模型\n",
|
653 |
+
"2. 奖励模型:使用的是`OpenAssistant/reward-model-deberta-v3-large-v2` 或者 Stage3得到的BERT类或者GPT类奖励模型\n",
|
654 |
+
"3. 数据集:RL阶段的数据可以复用SFT的数据集,使用的是Belle的1千条抽样数据,位于`data/finetune`文件夹"
|
655 |
+
],
|
656 |
+
"metadata": {
|
657 |
+
"collapsed": false
|
658 |
+
}
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"cell_type": "markdown",
|
662 |
+
"source": [
|
663 |
+
"## Stage4 咱们开始吧\n",
|
664 |
+
"\n",
|
665 |
+
"训练步骤如下:\n",
|
666 |
+
"\n",
|
667 |
+
"1. 确认训练集\n",
|
668 |
+
"2. 执行训练脚本\n",
|
669 |
+
"\n",
|
670 |
+
"训练脚本的执行逻辑如下:\n",
|
671 |
+
"1. 导入依赖包\n",
|
672 |
+
"2. 设置参数\n",
|
673 |
+
"3. 定义各函数并加载训练集\n",
|
674 |
+
"4. 加载生成模型和tokenizer,加载奖励模型和其tokenizer\n",
|
675 |
+
"5. 开始训练并评估\n",
|
676 |
+
"6. 查看训练结果\n",
|
677 |
+
"\n",
|
678 |
+
"以下参数可以根据你的GPU实际情况修改,当前参数是根据Colab的T4单卡GPU(16GB显存)配置的。"
|
679 |
+
],
|
680 |
+
"metadata": {
|
681 |
+
"collapsed": false
|
682 |
+
}
|
683 |
+
},
|
684 |
+
{
|
685 |
+
"cell_type": "code",
|
686 |
+
"execution_count": null,
|
687 |
+
"outputs": [],
|
688 |
+
"source": [
|
689 |
+
"%ls ./data/finetune/"
|
690 |
+
],
|
691 |
+
"metadata": {
|
692 |
+
"collapsed": false
|
693 |
+
}
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"cell_type": "code",
|
697 |
+
"execution_count": null,
|
698 |
+
"outputs": [],
|
699 |
+
"source": [
|
700 |
+
"!python rl_training.py \\\n",
|
701 |
+
" --model_type bloom \\\n",
|
702 |
+
" --model_name_or_path merged-sft \\\n",
|
703 |
+
" --reward_model_name_or_path merged-rm \\\n",
|
704 |
+
" --torch_dtype float16 \\\n",
|
705 |
+
" --device_map auto \\\n",
|
706 |
+
" --train_file_dir ./data/finetune \\\n",
|
707 |
+
" --validation_file_dir ./data/finetune \\\n",
|
708 |
+
" --batch_size 4 \\\n",
|
709 |
+
" --max_source_length 256 \\\n",
|
710 |
+
" --max_target_length 256 \\\n",
|
711 |
+
" --max_train_samples 1000 \\\n",
|
712 |
+
" --use_peft True \\\n",
|
713 |
+
" --lora_rank 8 \\\n",
|
714 |
+
" --lora_alpha 16 \\\n",
|
715 |
+
" --lora_dropout 0.05 \\\n",
|
716 |
+
" --do_train \\\n",
|
717 |
+
" --max_steps 64 \\\n",
|
718 |
+
" --learning_rate 1e-5 \\\n",
|
719 |
+
" --save_steps 50 \\\n",
|
720 |
+
" --output_dir outputs-rl-v1 \\\n",
|
721 |
+
" --early_stopping True \\\n",
|
722 |
+
" --target_kl 0.1 \\\n",
|
723 |
+
" --reward_baseline 0.0"
|
724 |
+
],
|
725 |
+
"metadata": {
|
726 |
+
"collapsed": false
|
727 |
+
}
|
728 |
+
},
|
729 |
+
{
|
730 |
+
"cell_type": "code",
|
731 |
+
"execution_count": null,
|
732 |
+
"outputs": [],
|
733 |
+
"source": [
|
734 |
+
"%ls -lh outputs-rl-v1"
|
735 |
+
],
|
736 |
+
"metadata": {
|
737 |
+
"collapsed": false
|
738 |
+
}
|
739 |
+
},
|
740 |
+
{
|
741 |
+
"cell_type": "markdown",
|
742 |
+
"source": [
|
743 |
+
"模型训练结果:\n",
|
744 |
+
"- 使用lora训练模型,则保存的lora权重是`adapter_model.bin`, lora配置文件是`adapter_config.json`,合并到base model的方法见`merge_peft_adapter.py`\n",
|
745 |
+
"- 日志保存在`output_dir/trl`目录下,可以使用tensorboard查看,启动tensorboard方式如下:`tensorboard --logdir output_dir/trl --host 0.0.0.0 --port 8009`"
|
746 |
+
],
|
747 |
+
"metadata": {
|
748 |
+
"collapsed": false
|
749 |
+
}
|
750 |
+
},
|
751 |
+
{
|
752 |
+
"cell_type": "markdown",
|
753 |
+
"source": [
|
754 |
+
"lora模型权重合并到base model,合并后的模型保存在`--output_dir`目录下,合并方法如下:"
|
755 |
+
],
|
756 |
+
"metadata": {
|
757 |
+
"collapsed": false
|
758 |
+
}
|
759 |
+
},
|
760 |
+
{
|
761 |
+
"cell_type": "code",
|
762 |
+
"execution_count": null,
|
763 |
+
"outputs": [],
|
764 |
+
"source": [
|
765 |
+
"!python merge_peft_adapter.py --model_type bloom \\\n",
|
766 |
+
" --base_model_name_or_path merged-sft --peft_model_path outputs-rl-v1 --output_dir merged-rl/"
|
767 |
+
],
|
768 |
+
"metadata": {
|
769 |
+
"collapsed": false
|
770 |
+
}
|
771 |
+
},
|
772 |
+
{
|
773 |
+
"cell_type": "code",
|
774 |
+
"execution_count": null,
|
775 |
+
"outputs": [],
|
776 |
+
"source": [
|
777 |
+
"%ls -lh merged-rl/"
|
778 |
+
],
|
779 |
+
"metadata": {
|
780 |
+
"collapsed": false
|
781 |
+
}
|
782 |
+
},
|
783 |
+
{
|
784 |
+
"cell_type": "code",
|
785 |
+
"execution_count": null,
|
786 |
+
"outputs": [],
|
787 |
+
"source": [
|
788 |
+
"%cat merged-rl/config.json"
|
789 |
+
],
|
790 |
+
"metadata": {
|
791 |
+
"collapsed": false
|
792 |
+
}
|
793 |
+
},
|
794 |
+
{
|
795 |
+
"cell_type": "markdown",
|
796 |
+
"source": [
|
797 |
+
"Stage4 RL第一次训练完成。\n",
|
798 |
+
"\n",
|
799 |
+
"**至此一个完整的4阶段训练流程演示完成。**"
|
800 |
+
],
|
801 |
+
"metadata": {
|
802 |
+
"collapsed": false
|
803 |
+
}
|
804 |
+
},
|
805 |
+
{
|
806 |
+
"cell_type": "markdown",
|
807 |
+
"source": [
|
808 |
+
"实际操作中Stage3和Stage4可以反复多次,直到RL得到的最后模型满足评估要求。\n",
|
809 |
+
"\n",
|
810 |
+
"RLHF过程可以把SFT模型当成一个初始化模型,RM模型当做指导老师,使用RL(PPO)调教SFT模型生成指导老师最满意的结果,如果小学老师满意了,我们就再训练一个中学老师,继续指导,中学老师满意了,就训练一个大学老师,这样不断迭代,使得生成模型的质量达到甚至超过人工撰写的天花板。\n",
|
811 |
+
"\n",
|
812 |
+
"RLHF训练不易,此项目提供给大家一种实现的方法和参考,希望抛砖引玉,共同促进中文开源LLM发展。"
|
813 |
+
],
|
814 |
+
"metadata": {
|
815 |
+
"collapsed": false
|
816 |
+
}
|
817 |
+
},
|
818 |
+
{
|
819 |
+
"cell_type": "markdown",
|
820 |
+
"source": [],
|
821 |
+
"metadata": {
|
822 |
+
"collapsed": false
|
823 |
+
}
|
824 |
+
},
|
825 |
+
{
|
826 |
+
"cell_type": "code",
|
827 |
+
"execution_count": null,
|
828 |
+
"outputs": [],
|
829 |
+
"source": [],
|
830 |
+
"metadata": {
|
831 |
+
"collapsed": false,
|
832 |
+
"ExecuteTime": {
|
833 |
+
"start_time": "2023-06-26T12:34:29.620609Z",
|
834 |
+
"end_time": "2023-06-26T12:34:29.658428Z"
|
835 |
+
}
|
836 |
+
}
|
837 |
+
},
|
838 |
+
{
|
839 |
+
"cell_type": "markdown",
|
840 |
+
"source": [
|
841 |
+
"# Test"
|
842 |
+
],
|
843 |
+
"metadata": {
|
844 |
+
"collapsed": false
|
845 |
+
}
|
846 |
+
},
|
847 |
+
{
|
848 |
+
"cell_type": "markdown",
|
849 |
+
"source": [],
|
850 |
+
"metadata": {
|
851 |
+
"collapsed": false
|
852 |
+
}
|
853 |
+
},
|
854 |
+
{
|
855 |
+
"cell_type": "code",
|
856 |
+
"execution_count": null,
|
857 |
+
"outputs": [],
|
858 |
+
"source": [
|
859 |
+
"!python inference.py --model_type bloom --base_model merged-rl --interactive"
|
860 |
+
],
|
861 |
+
"metadata": {
|
862 |
+
"collapsed": false,
|
863 |
+
"ExecuteTime": {
|
864 |
+
"start_time": "2023-06-26T12:34:47.802087Z",
|
865 |
+
"end_time": "2023-06-26T12:35:00.864463Z"
|
866 |
+
}
|
867 |
+
}
|
868 |
+
},
|
869 |
+
{
|
870 |
+
"cell_type": "markdown",
|
871 |
+
"source": [
|
872 |
+
"Input:介绍下南京\n",
|
873 |
+
"Response: 南京市位于江苏省西南部,是全国��批历史文化名城、国家中心城市和自由贸易试验区。\n",
|
874 |
+
"\n",
|
875 |
+
"完。\n"
|
876 |
+
],
|
877 |
+
"metadata": {
|
878 |
+
"collapsed": false
|
879 |
+
}
|
880 |
+
},
|
881 |
+
{
|
882 |
+
"cell_type": "code",
|
883 |
+
"execution_count": null,
|
884 |
+
"outputs": [],
|
885 |
+
"source": [],
|
886 |
+
"metadata": {
|
887 |
+
"collapsed": false
|
888 |
+
}
|
889 |
+
}
|
890 |
+
],
|
891 |
+
"metadata": {
|
892 |
+
"kernelspec": {
|
893 |
+
"name": "python3",
|
894 |
+
"language": "python",
|
895 |
+
"display_name": "Python 3"
|
896 |
+
},
|
897 |
+
"language_info": {
|
898 |
+
"codemirror_mode": {
|
899 |
+
"name": "ipython",
|
900 |
+
"version": 3
|
901 |
+
},
|
902 |
+
"file_extension": ".py",
|
903 |
+
"mimetype": "text/x-python",
|
904 |
+
"name": "python",
|
905 |
+
"nbconvert_exporter": "python",
|
906 |
+
"pygments_lexer": "ipython3",
|
907 |
+
"version": "3.8.13"
|
908 |
+
},
|
909 |
+
"vscode": {
|
910 |
+
"interpreter": {
|
911 |
+
"hash": "f34eed0bebedfc4b6ee51ced43d2c030fe3b92f13c149d072205ca200a67b1ec"
|
912 |
+
}
|
913 |
+
}
|
914 |
+
},
|
915 |
+
"nbformat": 4,
|
916 |
+
"nbformat_minor": 4
|
917 |
+
}
|
MedicalGPT-main/supervised_finetuning.py
ADDED
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright 2023 XuMing(xuming624@qq.com) and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
17 |
+
|
18 |
+
part of this code is adapted from https://github.com/shibing624/textgen
|
19 |
+
"""
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
from dataclasses import dataclass, field
|
23 |
+
from glob import glob
|
24 |
+
from typing import List, Optional, Dict, Sequence
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from datasets import load_dataset
|
28 |
+
from loguru import logger
|
29 |
+
from peft import LoraConfig, TaskType, get_peft_model, PeftModel, prepare_model_for_int8_training
|
30 |
+
from transformers import (
|
31 |
+
AutoConfig,
|
32 |
+
BloomForCausalLM,
|
33 |
+
AutoModel,
|
34 |
+
AutoModelForCausalLM,
|
35 |
+
LlamaTokenizer,
|
36 |
+
LlamaForCausalLM,
|
37 |
+
BloomTokenizerFast,
|
38 |
+
AutoTokenizer,
|
39 |
+
HfArgumentParser,
|
40 |
+
Trainer,
|
41 |
+
TrainingArguments,
|
42 |
+
set_seed,
|
43 |
+
BitsAndBytesConfig,
|
44 |
+
DataCollatorForSeq2Seq,
|
45 |
+
)
|
46 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
47 |
+
from transformers.trainer import TRAINING_ARGS_NAME
|
48 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
49 |
+
|
50 |
+
MODEL_CLASSES = {
|
51 |
+
"bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
|
52 |
+
"chatglm": (AutoConfig, AutoModel, AutoTokenizer),
|
53 |
+
"llama": (AutoConfig, LlamaForCausalLM, LlamaTokenizer),
|
54 |
+
"baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
55 |
+
"auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
|
56 |
+
}
|
57 |
+
|
58 |
+
|
59 |
+
@dataclass
|
60 |
+
class ModelArguments:
|
61 |
+
"""
|
62 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
63 |
+
"""
|
64 |
+
|
65 |
+
model_type: str = field(
|
66 |
+
default=None,
|
67 |
+
metadata={"help": "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())}
|
68 |
+
)
|
69 |
+
model_name_or_path: Optional[str] = field(
|
70 |
+
default=None,
|
71 |
+
metadata={
|
72 |
+
"help": (
|
73 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
74 |
+
)
|
75 |
+
},
|
76 |
+
)
|
77 |
+
tokenizer_name_or_path: Optional[str] = field(
|
78 |
+
default=None,
|
79 |
+
metadata={
|
80 |
+
"help": (
|
81 |
+
"The tokenizer for weights initialization.Don't set if you want to train a model from scratch."
|
82 |
+
)
|
83 |
+
},
|
84 |
+
)
|
85 |
+
load_in_8bit: bool = field(default=False, metadata={"help": "Whether to load the model in 8bit mode or not."})
|
86 |
+
cache_dir: Optional[str] = field(
|
87 |
+
default=None,
|
88 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
89 |
+
)
|
90 |
+
use_fast_tokenizer: bool = field(
|
91 |
+
default=False,
|
92 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
93 |
+
)
|
94 |
+
torch_dtype: Optional[str] = field(
|
95 |
+
default="float16",
|
96 |
+
metadata={
|
97 |
+
"help": (
|
98 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
99 |
+
"dtype will be automatically derived from the model's weights."
|
100 |
+
),
|
101 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
102 |
+
},
|
103 |
+
)
|
104 |
+
device_map: Optional[str] = field(
|
105 |
+
default="auto",
|
106 |
+
metadata={"help": "Device to map model to. If `auto` is passed, the device will be selected automatically. "},
|
107 |
+
)
|
108 |
+
trust_remote_code: bool = field(
|
109 |
+
default=True,
|
110 |
+
metadata={"help": "Whether to trust remote code when loading a model from a remote checkpoint."},
|
111 |
+
)
|
112 |
+
|
113 |
+
def __post_init__(self):
|
114 |
+
if self.model_type is None:
|
115 |
+
raise ValueError(
|
116 |
+
"You must specify a valid model_type to run training. Available model types are " + ", ".join(
|
117 |
+
MODEL_CLASSES.keys()))
|
118 |
+
if self.model_name_or_path is None:
|
119 |
+
raise ValueError("You must specify a valid model_name_or_path to run training.")
|
120 |
+
|
121 |
+
|
122 |
+
@dataclass
|
123 |
+
class DataTrainingArguments:
|
124 |
+
"""
|
125 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
126 |
+
"""
|
127 |
+
|
128 |
+
dataset_name: Optional[str] = field(
|
129 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
130 |
+
)
|
131 |
+
dataset_config_name: Optional[str] = field(
|
132 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
133 |
+
)
|
134 |
+
train_file_dir: Optional[str] = field(default=None, metadata={"help": "The train jsonl data file folder."})
|
135 |
+
validation_file_dir: Optional[str] = field(default=None, metadata={"help": "The evaluation jsonl file folder."})
|
136 |
+
template_name: Optional[str] = field(default="vicuna", metadata={"help": "The prompt template name."})
|
137 |
+
max_train_samples: Optional[int] = field(
|
138 |
+
default=None,
|
139 |
+
metadata={
|
140 |
+
"help": (
|
141 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
142 |
+
"value if set."
|
143 |
+
)
|
144 |
+
},
|
145 |
+
)
|
146 |
+
max_eval_samples: Optional[int] = field(
|
147 |
+
default=None,
|
148 |
+
metadata={
|
149 |
+
"help": (
|
150 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
151 |
+
"value if set."
|
152 |
+
)
|
153 |
+
},
|
154 |
+
)
|
155 |
+
max_source_length: Optional[int] = field(default=256, metadata={"help": "Max length of prompt input text"})
|
156 |
+
max_target_length: Optional[int] = field(default=256, metadata={"help": "Max length of output text"})
|
157 |
+
ignore_pad_token_for_loss: bool = field(
|
158 |
+
default=True,
|
159 |
+
metadata={"help": "If only pad tokens should be ignored. This assumes that `config.pad_token_id` is defined."},
|
160 |
+
)
|
161 |
+
overwrite_cache: bool = field(
|
162 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
163 |
+
)
|
164 |
+
validation_split_percentage: Optional[int] = field(
|
165 |
+
default=1,
|
166 |
+
metadata={
|
167 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
168 |
+
},
|
169 |
+
)
|
170 |
+
preprocessing_num_workers: Optional[int] = field(
|
171 |
+
default=None,
|
172 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
173 |
+
)
|
174 |
+
|
175 |
+
def __post_init__(self):
|
176 |
+
if self.max_train_samples is not None and 0 < self.max_train_samples <= 1000:
|
177 |
+
logger.warning("You may set max_train_samples = -1 to run all samples in production.")
|
178 |
+
if self.max_source_length < 30:
|
179 |
+
raise ValueError("You must specify a valid max_source_length >= 30 to run training.")
|
180 |
+
|
181 |
+
|
182 |
+
@dataclass
|
183 |
+
class PeftArguments(TrainingArguments):
|
184 |
+
use_peft: bool = field(default=True, metadata={"help": "Whether to use peft"})
|
185 |
+
target_modules: Optional[str] = field(default="all")
|
186 |
+
lora_rank: Optional[int] = field(default=8)
|
187 |
+
lora_dropout: Optional[float] = field(default=0.05)
|
188 |
+
lora_alpha: Optional[float] = field(default=32.0)
|
189 |
+
modules_to_save: Optional[str] = field(default=None)
|
190 |
+
peft_path: Optional[str] = field(default=None, metadata={"help": "The path to the peft model"})
|
191 |
+
qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
|
192 |
+
|
193 |
+
|
194 |
+
class CastOutputToFloat(torch.nn.Sequential):
|
195 |
+
"""Cast the output of the model to float"""
|
196 |
+
|
197 |
+
def forward(self, x):
|
198 |
+
return super().forward(x).to(torch.float32)
|
199 |
+
|
200 |
+
|
201 |
+
@dataclass
|
202 |
+
class Conversation:
|
203 |
+
"""A class that manages prompt templates and keeps all conversation history."""
|
204 |
+
|
205 |
+
# The name of this template
|
206 |
+
name: str
|
207 |
+
# The system prompt
|
208 |
+
system_prompt: str
|
209 |
+
# All messages. format: list of [question, answer]
|
210 |
+
messages: Optional[List[Sequence[str]]]
|
211 |
+
# The roles of the speakers
|
212 |
+
roles: Optional[Sequence[str]]
|
213 |
+
# Conversation prompt
|
214 |
+
prompt: str
|
215 |
+
# Separator
|
216 |
+
sep: str
|
217 |
+
# Stop token, default is tokenizer.eos_token
|
218 |
+
stop_str: Optional[str] = "</s>"
|
219 |
+
|
220 |
+
def get_prompt(
|
221 |
+
self,
|
222 |
+
messages: Optional[List[Sequence[str]]] = None,
|
223 |
+
system_prompt: Optional[str] = ""
|
224 |
+
) -> str:
|
225 |
+
"""
|
226 |
+
Returns a string containing prompt without response.
|
227 |
+
"""
|
228 |
+
return "".join(self._format_example(messages, system_prompt))
|
229 |
+
|
230 |
+
def get_dialog(
|
231 |
+
self,
|
232 |
+
messages: Optional[List[Sequence[str]]] = None,
|
233 |
+
system_prompt: Optional[str] = ""
|
234 |
+
) -> List[str]:
|
235 |
+
"""
|
236 |
+
Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
|
237 |
+
"""
|
238 |
+
return self._format_example(messages, system_prompt)
|
239 |
+
|
240 |
+
def _format_example(
|
241 |
+
self,
|
242 |
+
messages: Optional[List[Sequence[str]]] = None,
|
243 |
+
system_prompt: Optional[str] = ""
|
244 |
+
) -> List[str]:
|
245 |
+
system_prompt = system_prompt or self.system_prompt
|
246 |
+
system_prompt = system_prompt + self.sep if system_prompt else "" # add separator for non-empty system prompt
|
247 |
+
messages = messages or self.messages
|
248 |
+
convs = []
|
249 |
+
for turn_idx, [user_query, bot_resp] in enumerate(messages):
|
250 |
+
if turn_idx == 0:
|
251 |
+
convs.append(system_prompt + self.prompt.format(query=user_query))
|
252 |
+
convs.append(bot_resp)
|
253 |
+
else:
|
254 |
+
convs.append(self.sep + self.prompt.format(query=user_query))
|
255 |
+
convs.append(bot_resp)
|
256 |
+
return convs
|
257 |
+
|
258 |
+
def append_message(self, query: str, answer: str):
|
259 |
+
"""Append a new message."""
|
260 |
+
self.messages.append([query, answer])
|
261 |
+
|
262 |
+
|
263 |
+
# A global registry for all conversation templates
|
264 |
+
conv_templates: Dict[str, Conversation] = {}
|
265 |
+
|
266 |
+
|
267 |
+
def register_conv_template(template: Conversation):
|
268 |
+
"""Register a new conversation template."""
|
269 |
+
conv_templates[template.name] = template
|
270 |
+
|
271 |
+
|
272 |
+
"""Vicuna v1.1 template
|
273 |
+
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
|
274 |
+
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
|
275 |
+
"""
|
276 |
+
register_conv_template(
|
277 |
+
Conversation(
|
278 |
+
name="vicuna",
|
279 |
+
system_prompt="A chat between a curious user and an artificial intelligence assistant. "
|
280 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
281 |
+
messages=[],
|
282 |
+
roles=("USER", "ASSISTANT"),
|
283 |
+
prompt="USER: {query} ASSISTANT: ",
|
284 |
+
sep="</s>",
|
285 |
+
)
|
286 |
+
)
|
287 |
+
|
288 |
+
"""Alpaca template"""
|
289 |
+
register_conv_template(
|
290 |
+
Conversation(
|
291 |
+
name="alpaca",
|
292 |
+
system_prompt="Below is an instruction that describes a task. "
|
293 |
+
"Write a response that appropriately completes the request.",
|
294 |
+
messages=[],
|
295 |
+
roles=("### Instruction", "### Response"),
|
296 |
+
prompt="### Instruction:\n{query}\n\n### Response:\n",
|
297 |
+
sep="\n\n",
|
298 |
+
)
|
299 |
+
)
|
300 |
+
|
301 |
+
"""Baichuan-13B-Chat template
|
302 |
+
source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
|
303 |
+
Support: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
304 |
+
"""
|
305 |
+
register_conv_template(
|
306 |
+
Conversation(
|
307 |
+
name="baichuan-chat",
|
308 |
+
system_prompt="",
|
309 |
+
messages=[],
|
310 |
+
roles=("<reserved_102>", "<reserved_103>"),
|
311 |
+
prompt=" <reserved_102> {query} <reserved_103> ",
|
312 |
+
sep="</s>",
|
313 |
+
)
|
314 |
+
)
|
315 |
+
|
316 |
+
"""ziya template"""
|
317 |
+
register_conv_template(
|
318 |
+
Conversation(
|
319 |
+
name="ziya",
|
320 |
+
system_prompt="",
|
321 |
+
messages=[],
|
322 |
+
roles=("<human>", "<bot>"),
|
323 |
+
prompt="<human>:{query}\n<bot>:",
|
324 |
+
sep="\n",
|
325 |
+
)
|
326 |
+
)
|
327 |
+
|
328 |
+
"""Linly template"""
|
329 |
+
register_conv_template(
|
330 |
+
Conversation(
|
331 |
+
name="linly",
|
332 |
+
system_prompt="",
|
333 |
+
messages=[],
|
334 |
+
roles=("User", "Bot"),
|
335 |
+
prompt="User: {query}\nBot: ",
|
336 |
+
sep="\n",
|
337 |
+
)
|
338 |
+
)
|
339 |
+
|
340 |
+
"""ChatGLM1 template
|
341 |
+
source: https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1307
|
342 |
+
"""
|
343 |
+
register_conv_template(
|
344 |
+
Conversation(
|
345 |
+
name="chatglm",
|
346 |
+
system_prompt="",
|
347 |
+
messages=[],
|
348 |
+
roles=("问", "答"),
|
349 |
+
prompt="问:{query}\n答:",
|
350 |
+
sep="\n",
|
351 |
+
)
|
352 |
+
)
|
353 |
+
|
354 |
+
"""ChatGLM2 template
|
355 |
+
source: https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1007
|
356 |
+
"""
|
357 |
+
register_conv_template(
|
358 |
+
# source:
|
359 |
+
Conversation(
|
360 |
+
name="chatglm2",
|
361 |
+
system_prompt="",
|
362 |
+
messages=[],
|
363 |
+
roles=("问", "答"),
|
364 |
+
prompt="问:{query}\n\n答:",
|
365 |
+
sep="\n\n",
|
366 |
+
)
|
367 |
+
)
|
368 |
+
|
369 |
+
"""Phoenix template"""
|
370 |
+
register_conv_template(
|
371 |
+
Conversation(
|
372 |
+
name="phoenix",
|
373 |
+
system_prompt="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
374 |
+
messages=[],
|
375 |
+
roles=("Human", "Assistant"),
|
376 |
+
prompt="Human: <s>{query}</s>Assistant: ",
|
377 |
+
sep="</s>",
|
378 |
+
)
|
379 |
+
)
|
380 |
+
|
381 |
+
"""belle template
|
382 |
+
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
383 |
+
"""
|
384 |
+
register_conv_template(
|
385 |
+
Conversation(
|
386 |
+
name="belle",
|
387 |
+
system_prompt="",
|
388 |
+
messages=[],
|
389 |
+
roles=("Human", "Belle"),
|
390 |
+
prompt="Human: {query}\n\nBelle: ",
|
391 |
+
sep="\n\n",
|
392 |
+
)
|
393 |
+
)
|
394 |
+
|
395 |
+
"""aquila template
|
396 |
+
Supports: https://huggingface.co/qhduan/aquilachat-7b
|
397 |
+
"""
|
398 |
+
register_conv_template(
|
399 |
+
Conversation(
|
400 |
+
name="aquila",
|
401 |
+
system_prompt="A chat between a curious human and an artificial intelligence assistant. "
|
402 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
403 |
+
messages=[],
|
404 |
+
roles=("Human", "Assistant"),
|
405 |
+
prompt="Human: {query}###Assistant: ",
|
406 |
+
sep="###",
|
407 |
+
)
|
408 |
+
)
|
409 |
+
|
410 |
+
"""intern template
|
411 |
+
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
412 |
+
"""
|
413 |
+
register_conv_template(
|
414 |
+
Conversation(
|
415 |
+
name="intern",
|
416 |
+
system_prompt="",
|
417 |
+
messages=[],
|
418 |
+
roles=("<|User|>", "<|Bot|>"),
|
419 |
+
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
|
420 |
+
sep="<eoa>\n",
|
421 |
+
stop_str="<eoa>",
|
422 |
+
)
|
423 |
+
)
|
424 |
+
|
425 |
+
"""StarChat template"""
|
426 |
+
register_conv_template(
|
427 |
+
Conversation(
|
428 |
+
name="starchat",
|
429 |
+
system_prompt="<system>\n",
|
430 |
+
messages=[],
|
431 |
+
roles=("<|user|>", "<|assistant|>"),
|
432 |
+
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
|
433 |
+
sep="<|end|>\n",
|
434 |
+
stop_str="<|end|>",
|
435 |
+
)
|
436 |
+
)
|
437 |
+
|
438 |
+
"""llama2 template
|
439 |
+
reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
|
440 |
+
"""
|
441 |
+
register_conv_template(
|
442 |
+
Conversation(
|
443 |
+
name="llama2",
|
444 |
+
system_prompt="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
|
445 |
+
"Always answer as helpfully as possible, while being safe. "
|
446 |
+
"Your answers should not include any harmful, unethical, racist, sexist, "
|
447 |
+
"toxic, dangerous, or illegal content. "
|
448 |
+
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
449 |
+
"If a question does not make any sense, or is not factually coherent, "
|
450 |
+
"explain why instead of answering something not correct. "
|
451 |
+
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
|
452 |
+
messages=[],
|
453 |
+
roles=("[INST]", "[/INST]"),
|
454 |
+
prompt=" [INST] {query} [/INST] ",
|
455 |
+
sep="</s>",
|
456 |
+
)
|
457 |
+
)
|
458 |
+
|
459 |
+
"""llama2-zh template
|
460 |
+
Sources: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
|
461 |
+
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
462 |
+
"""
|
463 |
+
register_conv_template(
|
464 |
+
Conversation(
|
465 |
+
name="llama2-zh",
|
466 |
+
system_prompt="<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n",
|
467 |
+
messages=[],
|
468 |
+
roles=("[INST]", "[/INST]"),
|
469 |
+
prompt=" [INST] {query} [/INST] ",
|
470 |
+
sep="</s>",
|
471 |
+
)
|
472 |
+
)
|
473 |
+
"""XVERSE template
|
474 |
+
Supports: https://huggingface.co/xverse/XVERSE-13B-Chat
|
475 |
+
"""
|
476 |
+
register_conv_template(
|
477 |
+
Conversation(
|
478 |
+
name="xverse",
|
479 |
+
system_prompt="",
|
480 |
+
messages=[],
|
481 |
+
roles=("Human", "Assistant"),
|
482 |
+
prompt="Human: {query}\n\nAssistant: ",
|
483 |
+
sep="</s>",
|
484 |
+
)
|
485 |
+
)
|
486 |
+
|
487 |
+
"""Qwen template
|
488 |
+
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
489 |
+
chatml: https://xbot123.com/645a461b922f176d7cfdbc2d/
|
490 |
+
"""
|
491 |
+
register_conv_template(
|
492 |
+
Conversation(
|
493 |
+
name="chatml",
|
494 |
+
system_prompt="You are a helpful assistant.",
|
495 |
+
messages=[],
|
496 |
+
roles=("user", "assistant"),
|
497 |
+
prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
|
498 |
+
sep="<|im_end|>\n",
|
499 |
+
stop_str="<|im_end|>",
|
500 |
+
)
|
501 |
+
)
|
502 |
+
|
503 |
+
|
504 |
+
def get_conv_template(name: str) -> Conversation:
|
505 |
+
"""Get a conversation template."""
|
506 |
+
return conv_templates[name]
|
507 |
+
|
508 |
+
|
509 |
+
class SavePeftModelTrainer(Trainer):
|
510 |
+
"""
|
511 |
+
Trainer for lora models
|
512 |
+
"""
|
513 |
+
|
514 |
+
def save_model(self, output_dir=None, _internal_call=False):
|
515 |
+
"""Save the LoRA model."""
|
516 |
+
os.makedirs(output_dir, exist_ok=True)
|
517 |
+
if self.args.local_rank in [-1, 0]:
|
518 |
+
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
519 |
+
self.model.save_pretrained(output_dir)
|
520 |
+
|
521 |
+
|
522 |
+
def save_model(output_dir, model, tokenizer, args):
|
523 |
+
"""Save the model and the tokenizer."""
|
524 |
+
os.makedirs(output_dir, exist_ok=True)
|
525 |
+
|
526 |
+
# Take care of distributed/parallel training
|
527 |
+
model_to_save = model.module if hasattr(model, "module") else model
|
528 |
+
if args.local_rank in [-1, 0]:
|
529 |
+
model_to_save.save_pretrained(output_dir)
|
530 |
+
tokenizer.save_pretrained(output_dir)
|
531 |
+
torch.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
532 |
+
|
533 |
+
|
534 |
+
def print_trainable_parameters(model):
|
535 |
+
"""
|
536 |
+
Prints the number of trainable parameters in the model.
|
537 |
+
"""
|
538 |
+
trainable_params = 0
|
539 |
+
all_param = 0
|
540 |
+
for _, param in model.named_parameters():
|
541 |
+
all_param += param.numel()
|
542 |
+
if param.requires_grad:
|
543 |
+
trainable_params += param.numel()
|
544 |
+
print(
|
545 |
+
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
546 |
+
)
|
547 |
+
|
548 |
+
|
549 |
+
def find_all_linear_names(peft_model, int4=False, int8=False):
|
550 |
+
"""Find all linear layer names in the model. reference from qlora paper."""
|
551 |
+
cls = torch.nn.Linear
|
552 |
+
if int4 or int8:
|
553 |
+
import bitsandbytes as bnb
|
554 |
+
if int4:
|
555 |
+
cls = bnb.nn.Linear4bit
|
556 |
+
elif int8:
|
557 |
+
cls = bnb.nn.Linear8bitLt
|
558 |
+
lora_module_names = set()
|
559 |
+
for name, module in peft_model.named_modules():
|
560 |
+
if isinstance(module, cls):
|
561 |
+
# last layer is not add to lora_module_names
|
562 |
+
if 'lm_head' in name:
|
563 |
+
continue
|
564 |
+
if 'output_layer' in name:
|
565 |
+
continue
|
566 |
+
names = name.split('.')
|
567 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
568 |
+
return sorted(lora_module_names)
|
569 |
+
|
570 |
+
|
571 |
+
def main():
|
572 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, PeftArguments))
|
573 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
574 |
+
|
575 |
+
logger.info(f"Model args: {model_args}")
|
576 |
+
logger.info(f"Data args: {data_args}")
|
577 |
+
logger.info(f"Training args: {training_args}")
|
578 |
+
logger.info(
|
579 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
580 |
+
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
581 |
+
)
|
582 |
+
|
583 |
+
# Set seed before initializing model.
|
584 |
+
set_seed(training_args.seed)
|
585 |
+
|
586 |
+
if not model_args.model_type:
|
587 |
+
raise ValueError("Please specify a model_type, e.g. llama, chatglm, bloom, etc.")
|
588 |
+
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_args.model_type]
|
589 |
+
|
590 |
+
# Load tokenizer
|
591 |
+
tokenizer_kwargs = {
|
592 |
+
"cache_dir": model_args.cache_dir,
|
593 |
+
"use_fast": model_args.use_fast_tokenizer,
|
594 |
+
"trust_remote_code": model_args.trust_remote_code,
|
595 |
+
}
|
596 |
+
tokenizer_name_or_path = model_args.tokenizer_name_or_path
|
597 |
+
if not tokenizer_name_or_path:
|
598 |
+
tokenizer_name_or_path = model_args.model_name_or_path
|
599 |
+
tokenizer = tokenizer_class.from_pretrained(tokenizer_name_or_path, **tokenizer_kwargs)
|
600 |
+
prompt_template = get_conv_template(data_args.template_name)
|
601 |
+
if tokenizer.eos_token_id is None:
|
602 |
+
tokenizer.eos_token = prompt_template.stop_str # eos token is required for SFT
|
603 |
+
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
604 |
+
if tokenizer.pad_token_id is None:
|
605 |
+
if tokenizer.unk_token_id is not None:
|
606 |
+
tokenizer.pad_token = tokenizer.unk_token
|
607 |
+
else:
|
608 |
+
tokenizer.pad_token = tokenizer.eos_token
|
609 |
+
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
610 |
+
|
611 |
+
logger.debug(f"Tokenizer: {tokenizer}")
|
612 |
+
IGNORE_INDEX = LabelSmoother.ignore_index if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
613 |
+
|
614 |
+
# Get datasets
|
615 |
+
if data_args.dataset_name is not None:
|
616 |
+
# Downloading and loading a dataset from the hub.
|
617 |
+
raw_datasets = load_dataset(
|
618 |
+
data_args.dataset_name,
|
619 |
+
data_args.dataset_config_name,
|
620 |
+
cache_dir=model_args.cache_dir,
|
621 |
+
)
|
622 |
+
if "validation" not in raw_datasets.keys():
|
623 |
+
raw_datasets["validation"] = load_dataset(
|
624 |
+
data_args.dataset_name,
|
625 |
+
data_args.dataset_config_name,
|
626 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
627 |
+
cache_dir=model_args.cache_dir,
|
628 |
+
)
|
629 |
+
raw_datasets["train"] = load_dataset(
|
630 |
+
data_args.dataset_name,
|
631 |
+
data_args.dataset_config_name,
|
632 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
633 |
+
cache_dir=model_args.cache_dir,
|
634 |
+
)
|
635 |
+
else:
|
636 |
+
# Loading a dataset from local files.
|
637 |
+
data_files = {}
|
638 |
+
if data_args.train_file_dir is not None and os.path.exists(data_args.train_file_dir):
|
639 |
+
train_data_files = glob(f'{data_args.train_file_dir}/**/*.json', recursive=True) + glob(
|
640 |
+
f'{data_args.train_file_dir}/**/*.jsonl', recursive=True)
|
641 |
+
logger.info(f"train files: {train_data_files}")
|
642 |
+
data_files["train"] = train_data_files
|
643 |
+
if data_args.validation_file_dir is not None and os.path.exists(data_args.validation_file_dir):
|
644 |
+
eval_data_files = glob(f'{data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
|
645 |
+
f'{data_args.validation_file_dir}/**/*.jsonl', recursive=True)
|
646 |
+
logger.info(f"eval files: {eval_data_files}")
|
647 |
+
data_files["validation"] = eval_data_files
|
648 |
+
raw_datasets = load_dataset(
|
649 |
+
'json',
|
650 |
+
data_files=data_files,
|
651 |
+
cache_dir=model_args.cache_dir,
|
652 |
+
)
|
653 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
654 |
+
if "validation" not in raw_datasets.keys():
|
655 |
+
raw_datasets["validation"] = load_dataset(
|
656 |
+
'json',
|
657 |
+
data_files=data_files,
|
658 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
659 |
+
cache_dir=model_args.cache_dir,
|
660 |
+
)
|
661 |
+
raw_datasets["train"] = load_dataset(
|
662 |
+
'json',
|
663 |
+
data_files=data_files,
|
664 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
665 |
+
cache_dir=model_args.cache_dir,
|
666 |
+
)
|
667 |
+
logger.info(f"Raw datasets: {raw_datasets}")
|
668 |
+
|
669 |
+
# Preprocessing the datasets
|
670 |
+
max_source_length = data_args.max_source_length
|
671 |
+
max_target_length = data_args.max_target_length
|
672 |
+
max_length = max_source_length + max_target_length
|
673 |
+
|
674 |
+
def preprocess_function(examples):
|
675 |
+
"""
|
676 |
+
Preprocessing the datasets.
|
677 |
+
part of code modified from https://github.com/lm-sys/FastChat
|
678 |
+
"""
|
679 |
+
input_ids_list = []
|
680 |
+
targets_list = []
|
681 |
+
roles = ["human", "gpt"]
|
682 |
+
|
683 |
+
def get_dialog(examples):
|
684 |
+
for i, source in enumerate(examples['conversations']):
|
685 |
+
if len(source) < 2:
|
686 |
+
continue
|
687 |
+
data_role = source[0].get("from", "")
|
688 |
+
if data_role not in roles or data_role != roles[0]:
|
689 |
+
# Skip the first one if it is not from human
|
690 |
+
source = source[1:]
|
691 |
+
if len(source) < 2:
|
692 |
+
continue
|
693 |
+
messages = []
|
694 |
+
for j, sentence in enumerate(source):
|
695 |
+
data_role = sentence.get("from", "")
|
696 |
+
if data_role not in roles:
|
697 |
+
logger.warning(f"unknown role: {data_role}, {i}. (ignored)")
|
698 |
+
break
|
699 |
+
if data_role == roles[j % 2]:
|
700 |
+
messages.append(sentence["value"])
|
701 |
+
if len(messages) < 2 or len(messages) % 2 != 0:
|
702 |
+
continue
|
703 |
+
# Convert the list to pairs of elements
|
704 |
+
history_messages = [[messages[k], messages[k + 1]] for k in range(0, len(messages), 2)]
|
705 |
+
yield prompt_template.get_dialog(history_messages)
|
706 |
+
|
707 |
+
for dialog in get_dialog(examples):
|
708 |
+
input_ids, labels = [], []
|
709 |
+
|
710 |
+
for i in range(len(dialog) // 2):
|
711 |
+
source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=(i == 0))
|
712 |
+
target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False)
|
713 |
+
|
714 |
+
if len(source_ids) > max_source_length:
|
715 |
+
source_ids = source_ids[:max_source_length]
|
716 |
+
if len(target_ids) > max_target_length - 1: # eos token
|
717 |
+
target_ids = target_ids[:max_target_length - 1]
|
718 |
+
if len(source_ids) > 0 and source_ids[0] == tokenizer.eos_token_id:
|
719 |
+
source_ids = source_ids[1:]
|
720 |
+
if len(target_ids) > 0 and target_ids[-1] == tokenizer.eos_token_id:
|
721 |
+
target_ids = target_ids[:-1]
|
722 |
+
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
723 |
+
break
|
724 |
+
|
725 |
+
input_ids += source_ids + target_ids + [tokenizer.eos_token_id] # add eos token for each turn
|
726 |
+
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
727 |
+
|
728 |
+
input_ids_list.append(input_ids)
|
729 |
+
targets_list.append(labels)
|
730 |
+
|
731 |
+
return dict(
|
732 |
+
input_ids=input_ids_list,
|
733 |
+
labels=targets_list,
|
734 |
+
)
|
735 |
+
|
736 |
+
def filter_empty_labels(example):
|
737 |
+
"""Remove empty labels dataset."""
|
738 |
+
return not all(label == IGNORE_INDEX for label in example["labels"])
|
739 |
+
|
740 |
+
train_dataset = None
|
741 |
+
max_train_samples = 0
|
742 |
+
if training_args.do_train:
|
743 |
+
if "train" not in raw_datasets:
|
744 |
+
raise ValueError("--do_train requires a train dataset")
|
745 |
+
train_dataset = raw_datasets['train']
|
746 |
+
max_train_samples = len(train_dataset)
|
747 |
+
if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
|
748 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
749 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
750 |
+
logger.debug(f"Example train_dataset[0]: {train_dataset[0]}")
|
751 |
+
with training_args.main_process_first(desc="Train dataset tokenization"):
|
752 |
+
train_dataset = train_dataset.shuffle().map(
|
753 |
+
preprocess_function,
|
754 |
+
batched=True,
|
755 |
+
num_proc=data_args.preprocessing_num_workers,
|
756 |
+
remove_columns=train_dataset.column_names,
|
757 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
758 |
+
desc="Running tokenizer on dataset",
|
759 |
+
)
|
760 |
+
train_dataset = train_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
|
761 |
+
logger.debug(f"Num train_samples: {len(train_dataset)}")
|
762 |
+
logger.debug("Tokenized training example:")
|
763 |
+
logger.debug(f"Decode input_ids[0]: {tokenizer.decode(train_dataset[0]['input_ids'])}")
|
764 |
+
replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id
|
765 |
+
for label in list(train_dataset[0]['labels'])]
|
766 |
+
logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
|
767 |
+
|
768 |
+
eval_dataset = None
|
769 |
+
max_eval_samples = 0
|
770 |
+
if training_args.do_eval:
|
771 |
+
with training_args.main_process_first(desc="Eval dataset tokenization"):
|
772 |
+
if "validation" not in raw_datasets:
|
773 |
+
raise ValueError("--do_eval requires a validation dataset")
|
774 |
+
eval_dataset = raw_datasets["validation"]
|
775 |
+
max_eval_samples = len(eval_dataset)
|
776 |
+
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
|
777 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
778 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
779 |
+
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
|
780 |
+
eval_dataset = eval_dataset.map(
|
781 |
+
preprocess_function,
|
782 |
+
batched=True,
|
783 |
+
num_proc=data_args.preprocessing_num_workers,
|
784 |
+
remove_columns=eval_dataset.column_names,
|
785 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
786 |
+
desc="Running tokenizer on dataset",
|
787 |
+
)
|
788 |
+
eval_dataset = eval_dataset.filter(filter_empty_labels, num_proc=data_args.preprocessing_num_workers)
|
789 |
+
logger.debug(f"Num eval_samples: {len(eval_dataset)}")
|
790 |
+
logger.debug("Tokenized eval example:")
|
791 |
+
logger.debug(tokenizer.decode(eval_dataset[0]['input_ids']))
|
792 |
+
|
793 |
+
# Load model
|
794 |
+
if model_args.model_name_or_path:
|
795 |
+
torch_dtype = (
|
796 |
+
model_args.torch_dtype
|
797 |
+
if model_args.torch_dtype in ["auto", None]
|
798 |
+
else getattr(torch, model_args.torch_dtype)
|
799 |
+
)
|
800 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
801 |
+
ddp = world_size != 1
|
802 |
+
if ddp:
|
803 |
+
model_args.device_map = {"": int(os.environ["LOCAL_RANK"]) or 0}
|
804 |
+
if training_args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
|
805 |
+
logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
|
806 |
+
config = config_class.from_pretrained(
|
807 |
+
model_args.model_name_or_path,
|
808 |
+
trust_remote_code=model_args.trust_remote_code,
|
809 |
+
torch_dtype=torch_dtype,
|
810 |
+
cache_dir=model_args.cache_dir
|
811 |
+
)
|
812 |
+
model = model_class.from_pretrained(
|
813 |
+
model_args.model_name_or_path,
|
814 |
+
config=config,
|
815 |
+
load_in_8bit=model_args.load_in_8bit,
|
816 |
+
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
817 |
+
device_map=model_args.device_map,
|
818 |
+
trust_remote_code=model_args.trust_remote_code,
|
819 |
+
quantization_config=BitsAndBytesConfig(
|
820 |
+
load_in_4bit=True,
|
821 |
+
bnb_4bit_use_double_quant=True,
|
822 |
+
bnb_4bit_quant_type="nf4",
|
823 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
824 |
+
) if training_args.qlora else None,
|
825 |
+
)
|
826 |
+
if hasattr(model, 'lm_head'):
|
827 |
+
model.lm_head = CastOutputToFloat(model.lm_head)
|
828 |
+
else:
|
829 |
+
raise ValueError(f"Error, model_name_or_path is None, SFT must be loaded from a pre-trained model")
|
830 |
+
|
831 |
+
if training_args.use_peft:
|
832 |
+
logger.info("Fine-tuning method: LoRA(PEFT)")
|
833 |
+
if training_args.peft_path is not None:
|
834 |
+
logger.info(f"Peft from pre-trained model: {training_args.peft_path}")
|
835 |
+
model = PeftModel.from_pretrained(model, training_args.peft_path, is_trainable=True)
|
836 |
+
else:
|
837 |
+
target_modules = training_args.target_modules.split(',') if training_args.target_modules else None
|
838 |
+
if target_modules and 'all' in target_modules:
|
839 |
+
target_modules = find_all_linear_names(model, int4=False, int8=model_args.load_in_8bit)
|
840 |
+
modules_to_save = training_args.modules_to_save
|
841 |
+
if modules_to_save is not None:
|
842 |
+
modules_to_save = modules_to_save.split(',')
|
843 |
+
logger.info(f"Peft target_modules: {target_modules}")
|
844 |
+
logger.info(f"Peft lora_rank: {training_args.lora_rank}")
|
845 |
+
peft_config = LoraConfig(
|
846 |
+
task_type=TaskType.CAUSAL_LM,
|
847 |
+
target_modules=target_modules,
|
848 |
+
inference_mode=False,
|
849 |
+
r=training_args.lora_rank,
|
850 |
+
lora_alpha=training_args.lora_alpha,
|
851 |
+
lora_dropout=training_args.lora_dropout,
|
852 |
+
modules_to_save=modules_to_save)
|
853 |
+
model = get_peft_model(model, peft_config)
|
854 |
+
if model_args.load_in_8bit:
|
855 |
+
model = prepare_model_for_int8_training(model)
|
856 |
+
model.print_trainable_parameters()
|
857 |
+
else:
|
858 |
+
logger.info("Fine-tuning method: Full parameters training")
|
859 |
+
model = model.float()
|
860 |
+
print_trainable_parameters(model)
|
861 |
+
logger.debug(f"Model: {model}")
|
862 |
+
|
863 |
+
# Initialize our Trainer
|
864 |
+
if training_args.gradient_checkpointing:
|
865 |
+
model.gradient_checkpointing_enable()
|
866 |
+
model.config.use_cache = False
|
867 |
+
else:
|
868 |
+
model.config.use_cache = True
|
869 |
+
model.enable_input_require_grads()
|
870 |
+
if not ddp and torch.cuda.device_count() > 1:
|
871 |
+
# Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
|
872 |
+
model.is_parallelizable = True
|
873 |
+
model.model_parallel = True
|
874 |
+
|
875 |
+
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
|
876 |
+
# Initialize our Trainer
|
877 |
+
trainer = SavePeftModelTrainer(
|
878 |
+
model=model,
|
879 |
+
args=training_args,
|
880 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
881 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
882 |
+
tokenizer=tokenizer,
|
883 |
+
data_collator=data_collator,
|
884 |
+
)
|
885 |
+
|
886 |
+
# Training
|
887 |
+
if training_args.do_train:
|
888 |
+
logger.info("*** Train ***")
|
889 |
+
sample = next(iter(trainer.get_train_dataloader()))
|
890 |
+
logger.debug(f"Train dataloader example: {sample}")
|
891 |
+
logger.debug(f"Detail input_ids: {list(sample['input_ids'])[:3]}, \nlabels: {list(sample['labels'])[:3]}")
|
892 |
+
logger.debug(f"Decode input_ids[0]: {tokenizer.decode(sample['input_ids'][0])}")
|
893 |
+
replaced_labels = [label if label != IGNORE_INDEX else tokenizer.pad_token_id for label in sample['labels'][0]]
|
894 |
+
logger.debug(f"Decode labels[0]: {tokenizer.decode(replaced_labels)}")
|
895 |
+
checkpoint = None
|
896 |
+
if training_args.resume_from_checkpoint is not None:
|
897 |
+
checkpoint = training_args.resume_from_checkpoint
|
898 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
899 |
+
|
900 |
+
metrics = train_result.metrics
|
901 |
+
metrics["train_samples"] = max_train_samples
|
902 |
+
logger.debug(f"Training metrics: {metrics}")
|
903 |
+
trainer.log_metrics("train", metrics)
|
904 |
+
trainer.save_metrics("train", metrics)
|
905 |
+
model.config.use_cache = True # enable cache after training
|
906 |
+
trainer.save_state()
|
907 |
+
logger.info(f"Saving model checkpoint to {training_args.output_dir}")
|
908 |
+
save_model(training_args.output_dir, model, tokenizer, training_args)
|
909 |
+
|
910 |
+
# Evaluation
|
911 |
+
if training_args.do_eval and trainer.is_world_process_zero():
|
912 |
+
logger.info("*** Evaluate ***")
|
913 |
+
metrics = trainer.evaluate()
|
914 |
+
|
915 |
+
metrics["eval_samples"] = max_eval_samples
|
916 |
+
try:
|
917 |
+
perplexity = math.exp(metrics["eval_loss"])
|
918 |
+
except OverflowError:
|
919 |
+
perplexity = float("inf")
|
920 |
+
metrics["perplexity"] = perplexity
|
921 |
+
logger.debug(f"Eval metrics: {metrics}")
|
922 |
+
trainer.log_metrics("eval", metrics)
|
923 |
+
trainer.save_metrics("eval", metrics)
|
924 |
+
|
925 |
+
|
926 |
+
if __name__ == "__main__":
|
927 |
+
main()
|