fangshengren committed on
Commit f4fac26
1 Parent(s): dfbdf47

Upload 59 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+img/stream_chat.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,27 @@
.vscode/*
.vscode
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/
.idea/

# python cache
*.pyc
*.cache

logs/*

data/*
!/data/my_train_dataset_3k.parquet
!/data/my_test_dataset_2k.parquet
!/data/my_valid_dataset_1k.parquet

model_save/*
!model_save/put_model_files_here

wandb/*
LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.en.md ADDED
@@ -0,0 +1,457 @@
<div align="center">

# A Small Chinese Chat Language Model: ChatLM-Chinese-0.2B
[中文](./README.md) | English

</div>

# 1. 👋Introduction

Today's large language models tend to have huge parameter counts, and consumer-grade computers are slow even at simple inference, let alone training a model from scratch. The goal of this project is to train a generative language model from scratch, covering data cleaning, tokenizer training, model pre-training, SFT instruction fine-tuning, RLHF optimization, and so on.

ChatLM-mini-Chinese is a small Chinese chat model with only 0.2B parameters (about 210M counting shared weights). It can be pre-trained on a machine with as little as 4GB of GPU memory (`batch_size=1`, `fp16` or `bf16`), and `float16` loading and inference require only 512MB of GPU memory at minimum.

- All sources of the pre-training, SFT instruction fine-tuning, and DPO preference optimization datasets are public.
- Uses the `Huggingface` NLP stack, including `transformers`, `accelerate`, `trl`, `peft`, etc.
- Self-implemented `trainer`, supporting pre-training and SFT fine-tuning on a single GPU or multiple GPUs on one machine. Training can be stopped at any position and resumed from any position.
- Pre-training: end-to-end `Text-to-Text` pre-training rather than `mask` prediction pre-training.
  - All data cleaning (normalization, MinHash-based document deduplication, etc.), dataset construction, and dataset loading optimizations are open source;
  - Multi-process word-frequency counting for the tokenizer; supports tokenizer training with both `sentencepiece` and `huggingface tokenizers`;
  - Pre-training supports checkpointing at any step, and training can be resumed from a breakpoint;
  - Streaming loading of large (GB-scale) datasets with buffered shuffling; neither RAM nor disk is used as a cache, which effectively reduces memory and disk usage. With `batch_size=1, max_len=320`, pre-training works on a machine with as little as 16GB RAM + 4GB GPU memory;
  - Training logging.
- SFT fine-tuning: the SFT dataset and its processing pipeline are open source.
  - The self-implemented `trainer` supports prompt instruction fine-tuning and resuming from any breakpoint;
  - Supports `sequence to sequence` fine-tuning with the `Huggingface trainer`;
  - Supports the traditional low-learning-rate approach of fine-tuning only the decoder layers.
- RLHF preference optimization: full preference optimization with DPO.
  - Supports preference optimization with `peft lora`;
  - Supports model merging: a `Lora adapter` can be merged back into the original model.
- Supports downstream task fine-tuning: [finetune_examples](./finetune_examples/info_extract/) gives a fine-tuning example for the **triple information extraction task**. The model keeps its dialogue ability after fine-tuning.

If you need retrieval-augmented generation (RAG) based on small models, see my other project [Phi2-mini-Chinese](https://github.com/charent/Phi2-mini-Chinese). For the code, see [rag_with_langchain.ipynb](https://github.com/charent/Phi2-mini-Chinese/blob/main/rag_with_langchain.ipynb)

🟢**Latest Update**

<details open>
<summary> <b>2024-01-30</b> </summary>
- The model files have been uploaded to ModelScope and can be downloaded quickly via `snapshot_download`. <br/>
</details>

<details close>
<summary> <b>2024-01-07</b> </summary>
- Added MinHash-based document deduplication to the data cleaning process (in this project it actually deduplicates dataset rows), to prevent the model from regurgitating training data at inference time after seeing the same data repeatedly. <br/>
- Added the `DropDatasetDuplicate` class for deduplicating documents in large datasets. <br/>
</details>

<details close>
<summary> <b>2023-12-29</b> </summary>
- Updated the model code (weights unchanged); the model can now be loaded directly with `AutoModelForSeq2SeqLM.from_pretrained(...)`. <br/>
- Updated the readme. <br/>
</details>

<details close>
<summary> <b>2023-12-18</b> </summary>
- Added code for fine-tuning the `ChatLM-mini-0.2B` model on the downstream triple information extraction task, plus a showcase of the extraction results. <br/>
- Updated the readme. <br/>
</details>

<details close>
<summary> <b>2023-12-14</b> </summary>
- Updated the model weight files after SFT and DPO. <br/>
- Updated the pre-training, SFT, and DPO scripts. <br/>
- Updated the `tokenizer` to `PreTrainedTokenizerFast`. <br/>
- Refactored the `dataset` code to support dynamic maximum length: the maximum length of each batch is determined by the longest text in that batch, saving GPU memory. <br/>
- Added `tokenizer` training details. <br/>
</details>

<details close>
<summary> <b>2023-12-04</b> </summary>
- Updated `generate` parameters and the model showcase. <br/>
- Updated the readme. <br/>
</details>

<details close>
<summary> <b>2023-11-28</b> </summary>
- Updated the DPO training code and model weights. <br/>
</details>

<details close>
<summary> <b>2023-10-19</b> </summary>
- The project is open source, and the model weights are available for download. <br/>
</details>

# 2. 🛠️ChatLM-0.2B-Chinese model training process
## 2.1 Pre-training dataset
All datasets come from **single-round conversation** datasets published on the Internet. After cleaning and formatting, they are saved as parquet files. For the processing pipeline, see `utils/raw_data_process.py`. The main datasets are:

1. Community Q&A json version webtext2019zh, a large-scale high-quality dataset, see [nlp_chinese_corpus](https://github.com/brightmart/nlp_chinese_corpus). 4.1 million in total, 2.6 million remaining after cleaning.
2. baike_qa2019 encyclopedia Q&A, see <https://aistudio.baidu.com/datasetdetail/107726>. 1.4 million in total, 1.3 million remaining after cleaning.
3. Chinese medical Q&A dataset, see [Chinese-medical-dialogue-data](https://github.com/Toyhom/Chinese-medical-dialogue-data). 790,000 in total, 790,000 remaining after cleaning.
4. ~~Financial industry Q&A data, see <https://zhuanlan.zhihu.com/p/609821974>. 770,000 in total, 520,000 remaining after cleaning.~~ **The data quality is too poor; not used.**
5. Zhihu Q&A data, see [Zhihu-KOL](https://huggingface.co/datasets/wangrui6/Zhihu-KOL). 1 million rows in total, 970,000 rows remaining after cleaning.
6. belle open-source instruction training data; introduction: [BELLE](https://github.com/LianjiaTech/BELLE), download: [BelleGroup](https://huggingface.co/BelleGroup). Only the parts of `Belle_open_source_1M`, `train_2M_CN`, and `train_3.5M_CN` with short answers, no complex table structure, and no translation tasks (the vocabulary has no English word list) were selected, totaling 3.7 million rows, with 3.38 million rows remaining after cleaning.
7. Wikipedia entry data: entries are assembled into prompts, and the first `N` words of each article serve as the answer. Uses the `202309` Wikipedia dump; 1.19 million entry prompts and answers remain after cleaning. Wiki download: [zhwiki](https://dumps.wikimedia.org/zhwiki/); to convert the downloaded bz2 file to wiki.txt, see [WikiExtractor](https://github.com/apertium/WikiExtractor).

The datasets total 10.23 million samples. Text-to-Text pre-training set: 9.3 million; evaluation set: 25,000 (because decoding is slow, the evaluation set is kept small). ~~Test set: 900,000.~~
The SFT fine-tuning and DPO optimization datasets are described below.
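
The cleaning pipeline above deduplicates documents (dataset rows) with MinHash. Below is a minimal sketch of the idea using the `datasketch` library; the library choice, n-gram size, and threshold are assumptions for illustration, while the project's own implementation is its `DropDatasetDuplicate` class:

```python
from datasketch import MinHash, MinHashLSH

def doc_minhash(text: str, num_perm: int = 128) -> MinHash:
    """Hash character 3-grams of a document into a MinHash signature."""
    m = MinHash(num_perm=num_perm)
    for i in range(max(len(text) - 2, 1)):
        m.update(text[i:i + 3].encode("utf-8"))
    return m

# LSH index that treats rows with Jaccard similarity >= 0.85 as duplicates
lsh = MinHashLSH(threshold=0.85, num_perm=128)

docs = ["今天天气真好", "今天天气真好!", "一段完全不同的文本"]
kept = []
for idx, doc in enumerate(docs):
    mh = doc_minhash(doc)
    if not lsh.query(mh):          # no near-duplicate indexed yet
        lsh.insert(f"doc-{idx}", mh)
        kept.append(doc)
print(kept)
```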

## 2.2 Model
T5 model (Text-to-Text Transfer Transformer); for details, see the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683).

The model source code comes from huggingface; see [T5ForConditionalGeneration](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1557).

For the model configuration, see [model_config.json](https://huggingface.co/charent/ChatLM-mini-Chinese/blob/main/config.json). The official `T5-base` uses 12 `encoder layer`s and 12 `decoder layer`s; this project reduces both to 10 layers.

Model parameters: 0.2B. Vocabulary size: 29298, containing only Chinese and a small amount of English.
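
For reference, the shrunken architecture can be expressed with `transformers`' `T5Config`; only the layer counts and vocabulary size below come from this section, the remaining hyperparameters follow the repo's `config.json` in practice and are left at library defaults here:

```python
from transformers import T5Config, T5ForConditionalGeneration

# T5-base uses 12 encoder and 12 decoder layers; this project trims both to 10
config = T5Config(
    vocab_size=29298,        # Chinese-centric vocabulary from section 2.2
    num_layers=10,           # encoder layers
    num_decoder_layers=10,   # decoder layers
)
model = T5ForConditionalGeneration(config)
print(f"{model.num_parameters() / 1e6:.0f}M parameters")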

## 2.3 Training process
Hardware:
```bash
# Pre-training phase:
CPU: 28 vCPU Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz
Memory: 60 GB
GPU: RTX A5000 (24GB) * 2

# SFT and DPO phases:
CPU: Intel(R) i5-13600k @ 5.1GHz
Memory: 32 GB
GPU: NVIDIA GeForce RTX 4060 Ti 16GB * 1
```

1. **Tokenizer training**: existing `tokenizer` training libraries run into OOM problems on large corpora, so the full corpus was merged and the vocabulary built by word frequency with a method similar to `BPE`; the run took half a day.

2. **Text-to-Text pre-training**: a dynamic learning rate from `1e-4` to `5e-3`; pre-training took 8 days. Training loss:
![training loss](img/train_loss.png)

3. **Prompt supervised fine-tuning (SFT)**: uses the `belle` instruction training dataset (both instruction and answer lengths below 512), with a dynamic learning rate from `1e-7` to `5e-5`; fine-tuning took 2 days. Fine-tuning loss:
![finetune loss](img/sft_loss.png)

4. **DPO direct preference optimization (RLHF)**: the dataset [alpaca-gpt4-data-zh](https://huggingface.co/datasets/c-s-ale/alpaca-gpt4-data-zh) provides the `chosen` texts; the SFT model from step `3` runs batch `generate` on the prompts in the dataset to obtain the `rejected` texts, which takes 1 day. Full DPO preference optimization with learning rate `1e-5`, half precision `fp16`, `2` `epoch`s in total, taking 3 hours. DPO loss:
![dpo loss](img/dpo_loss.png)

## 2.4 Chat showcase
### 2.4.1 Stream chat
By default, `TextIteratorStreamer` from `huggingface transformers` is used to implement streaming dialogue, which only supports `greedy search`. If you need `beam sample` or other generation methods, change the `stream_chat` parameter in `cli_demo.py` to `False`.
![](./img/stream_chat.gif)
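
A minimal sketch of how `TextIteratorStreamer` is typically wired up: generation runs in a background thread while the main thread consumes decoded pieces. This reuses the `tokenizer`, `model`, `input_ids`, and `device` objects from the quick-start example in section 3.1 below, and uses the standard `generate` API rather than the project's own demo code; the generation keyword arguments are illustrative:

```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

# generate() blocks, so run it in a worker thread and iterate on the streamer
generation_kwargs = dict(
    input_ids=input_ids.to(device),
    attention_mask=attention_mask.to(device),
    max_new_tokens=256,
    streamer=streamer,
)
Thread(target=model.generate, kwargs=generation_kwargs).start()

for new_text in streamer:   # yields decoded text piece by piece
    print(new_text, end="", flush=True)
```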

### 2.4.2 Dialogue showcase
![](./img/show1.png)

Known issues: the pre-training dataset has only a bit more than 9 million samples and the model has only 0.2B parameters, so it cannot cover every topic; there will be cases where the answer is wrong or the model generates nonsense.

# 3. 📑Usage instructions
## 3.1 Quick start:
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_id = 'charent/ChatLM-mini-Chinese'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# If huggingface is unreachable, uncomment the following two lines to download the
# model files from modelscope instead; they are saved to the './model_save' directory
# from modelscope import snapshot_download
# model_id = snapshot_download(model_id, cache_dir='./model_save')

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True).to(device)

txt = '如何评价Apple这家公司?'

encode_ids = tokenizer([txt])
input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])

outs = model.my_generate(
    input_ids=input_ids.to(device),
    attention_mask=attention_mask.to(device),
    max_seq_len=256,
    search_type='beam',
)

outs_txt = tokenizer.batch_decode(outs.cpu().numpy(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(outs_txt[0])
```
```txt
Apple是一家专注于设计和用户体验的公司,其产品在设计上注重简约、流畅和功能性,而在用户体验方面则注重用户的反馈和使用体验。作为一家领先的科技公司,苹果公司一直致力于为用户提供最优质的产品和服务,不断推陈出新,不断创新和改进,以满足不断变化的市场需求。
在iPhone、iPad和Mac等产品上,苹果公司一直保持着创新的态度,不断推出新的功能和设计,为用户提供更好的使用体验。在iPad上推出的iPad Pro和iPod touch等产品,也一直保持着优秀的用户体验。
此外,苹果公司还致力于开发和销售软件和服务,例如iTunes、iCloud和App Store等,这些产品在市场上也获得了广泛的认可和好评。
总的来说,苹果公司在设计、用户体验和产品创新方面都做得非常出色,为用户带来了许多便利和惊喜。

```

## 3.2 Starting from a clone of the repository
> [!CAUTION]
> The model in this project is a `TextToText` model. In the `prompt`, `response`, and other fields of the pre-training, SFT, and RLHF stages, be sure to append the `[EOS]` end-of-sequence marker.
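
For example, when building a training sample, both fields carry the marker (a hedged sketch; the text is a hypothetical placeholder, and the two-column layout follows the dataset format described in section 3.5):

```python
EOS = '[EOS]'

sample = {
    'prompt': '请介绍一下大熊猫' + EOS,
    'response': '大熊猫是中国特有的珍稀哺乳动物……' + EOS,
}
```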

### 3.2.1 Clone the repository
```bash
git clone --depth 1 https://github.com/charent/ChatLM-mini-Chinese.git

cd ChatLM-mini-Chinese
```
### 3.2.2 Install dependencies
`python 3.10` is recommended for this project; older python versions may not be compatible with the third-party libraries it depends on.

pip installation:
```bash
pip install -r ./requirements.txt
```

If pip installed the CPU version of pytorch, you can install the CUDA version with:
```bash
# pip install torch + cu118
pip3 install torch --index-url https://download.pytorch.org/whl/cu118
```

conda installation:
```bash
conda install --yes --file ./requirements.txt
```

### 3.2.3 Download the pre-trained model and model configuration files

Download the model weights and configuration files from the `Hugging Face Hub` with `git`. You need to install [Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) first, then run:

```bash
# Use git to download the huggingface model. Install [Git LFS] first, otherwise the downloaded model files will be unusable.
git clone --depth 1 https://huggingface.co/charent/ChatLM-mini-Chinese

# If huggingface is unreachable, download from modelscope instead
git clone --depth 1 https://www.modelscope.cn/charent/ChatLM-mini-Chinese.git

mv ChatLM-mini-Chinese model_save
```

You can also download the files manually from the `Hugging Face Hub` repository [ChatLM-mini-Chinese](https://huggingface.co/charent/ChatLM-mini-Chinese) and move them into the `model_save` directory.


## 3.3 Tokenizer training

1. Prepare a txt corpus

   The corpus should be as comprehensive as possible; it is recommended to combine several corpora, such as encyclopedias, code, papers, blogs, conversations, etc.

   This project is mainly based on the Chinese Wikipedia. To obtain the Chinese wiki corpus: download the `zhwiki-[archive date]-pages-articles-multistream.xml.bz2` file (about 2.7GB) from [zhwiki](https://dumps.wikimedia.org/zhwiki/); to convert the downloaded bz2 file to wiki.txt, see [WikiExtractor](https://github.com/apertium/WikiExtractor); then use python's `OpenCC` library to convert it to Simplified Chinese, and finally put the resulting `wiki.simple.txt` in the `data` directory in the project root. Merge multiple corpora into a single `txt` file yourself.

   Since tokenizer training consumes a lot of memory, if your corpus is very large (the merged `txt` file exceeds 2GB), it is recommended to sample the corpus by category and proportion to reduce training time and memory consumption. Training on a 1.7GB `txt` file needs about 48GB of RAM (estimated; I only have 32GB, swap was triggered frequently, and the computer froze for a long while T_T) and takes about 1 hour on a 13600k CPU.

2. Train the tokenizer

   The difference between `char level` and `byte level` is shown below (please look up the practical differences yourself). A `char level` tokenizer is trained by default; if you need `byte level`, just set `token_type='byte'` in `train_tokenizer.py`.

```python
# original text
txt = '这是一段中英混输的句子, (chinese and English, here are words.)'

tokens = charlevel_tokenizer.tokenize(txt)
print(tokens)
# char level tokens output
# ['▁这是', '一段', '中英', '混', '输', '的', '句子', '▁,', '▁(', '▁ch', 'inese', '▁and', '▁Eng', 'lish', '▁,', '▁h', 'ere', '▁', 'are', '▁w', 'ord', 's', '▁.', '▁)']

tokens = bytelevel_tokenizer.tokenize(txt)
print(tokens)
# byte level tokens output
# ['Ġè¿Ļæĺ¯', 'ä¸Ģ段', 'ä¸Ńèĭ±', 'æ··', 'è¾ĵ', 'çļĦ', 'åı¥åŃIJ', 'Ġ,', 'Ġ(', 'Ġch', 'inese', 'Ġand', 'ĠEng', 'lish', 'Ġ,', 'Ġh', 'ere', 'Ġare', 'Ġw', 'ord', 's', 'Ġ.', 'Ġ)']
```

Start training:

```bash
# Make sure your training corpus `txt` file is in the data directory
python train_tokenizer.py
```
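
If you want a rough picture of what such a training script does, here is a minimal sketch using the `huggingface tokenizers` BPE trainer. The file name and special tokens are assumptions; the project's `train_tokenizer.py` additionally does its own multi-process word-frequency counting:

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

trainer = trainers.BpeTrainer(
    vocab_size=29298,
    special_tokens=["[PAD]", "[UNK]", "[EOS]"],
)
tokenizer.train(files=["./data/wiki.simple.txt"], trainer=trainer)
tokenizer.save("./model_save/my_tokenizer.json")
```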

## 3.4 Text-to-Text pre-training
1. Pre-training dataset example
```json
{
    "prompt": "对于花园街,你有什么了解或看法吗?",
    "response": "花园街(是香港油尖旺区的一条富有特色的街道,位于九龙旺角东部,北至界限街,南至登打士街,与通菜街及洗衣街等街道平行。现时这条街道是香港著名的购物区之一。位于亚皆老街以南的一段花园街,也就是\"波鞋街\"整条街约150米长,有50多间售卖运动鞋和运动用品的店舖。旺角道至太子道西一段则为排档区,售卖成衣、蔬菜和水果等。花园街一共分成三段。明清时代,花园街是芒角村栽种花卉的地方。此外,根据历史专家郑宝鸿的考证:花园街曾是1910年代东方殷琴拿烟厂的花园。纵火案。自2005年起,花园街一带最少发生5宗纵火案,当中4宗涉及排档起火。2010年。2010年12月6日,花园街222号一个卖鞋的排档于凌晨5时许首先起火,浓烟涌往旁边住宅大厦,消防接报4"
}
```

2. jupyter-lab or jupyter notebook:

   See `train.ipynb`. jupyter-lab is recommended, so you don't have to worry about the terminal process being killed after disconnecting from the server.

3. Console:

   Console training has to account for the process being killed when the connection drops; it is recommended to use the process daemon tool `Supervisor` or `screen` to keep the session alive.

   First configure `accelerate` by running the following command and following the prompts; see `accelerate.yaml` for reference. *Note: installing DeepSpeed on Windows is rather troublesome.*
```bash
accelerate config
```

Start training. If you want to use the configuration provided by the project, add `--config_file ./accelerate.yaml` after `accelerate launch` in the commands below. *That configuration assumes a single machine with 2 GPUs.*

*There are two pre-training scripts: `train.py` uses the trainer implemented in this project, and `pre_train.py` uses the huggingface trainer. Either works, and the results are the same. The project's trainer displays training information more nicely and makes it easier to modify training details (such as the loss function and logging); both support checkpoint resumption, and the project's trainer supports resuming from a breakpoint at any position; pressing `ctrl+c` saves breakpoint information when exiting the script.*

Single machine, single GPU:
```bash
# The trainer implemented in this project
accelerate launch ./train.py train

# Or use the huggingface trainer
python pre_train.py
```

Single machine, multiple GPUs:
`2` is the number of GPUs; modify it to match your setup.
```bash
# The trainer implemented in this project
accelerate launch --multi_gpu --num_processes 2 ./train.py train

# Or use the huggingface trainer
accelerate launch --multi_gpu --num_processes 2 pre_train.py
```

Continue training from a breakpoint:
```bash
# The trainer implemented in this project
accelerate launch --multi_gpu --num_processes 2 ./train.py train --is_keep_training=True

# Or use the huggingface trainer
# You need to add `resume_from_checkpoint=True` to the `train` function call in `pre_train.py`
python pre_train.py
```
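
As an aside on the dataloading described in section 1 (streaming with buffered shuffling), the behavior can be sketched with the `datasets` library; the file path reuses one of the repo's sample parquet files, and the buffer size is illustrative, while the project ships its own optimized loader:

```python
from datasets import load_dataset

# Stream the parquet shards instead of materializing them in RAM
dataset = load_dataset(
    "parquet",
    data_files="./data/my_train_dataset_3k.parquet",
    split="train",
    streaming=True,
)

# Shuffle inside a fixed-size buffer; memory use stays bounded
dataset = dataset.shuffle(seed=42, buffer_size=10_000)

for sample in dataset.take(2):
    print(sample["prompt"], "->", sample["response"][:32])
```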

## 3.5 Supervised fine-tuning (SFT)

The SFT dataset comes entirely from the contributions of [BELLE](https://github.com/LianjiaTech/BELLE); thank you. The SFT datasets are [generated_chat_0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M), [train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN), and [train_2M_CN](https://huggingface.co/datasets/BelleGroup/train_2M_CN); about 1.37 million rows remain after cleaning.
Example from the SFT instruction fine-tuning dataset:

```json
{
    "prompt": "解释什么是欧洲启示录",
    "response": "欧洲启示录(The Book of Revelation)是新约圣经的最后一卷书,也被称为《启示录》、《默示录》或《约翰默示录》。这本书从宗教的角度描述了世界末日的来临,以及上帝对世界的审判和拯救。 书中的主题包括来临的基督的荣耀,上帝对人性的惩罚和拯救,以及魔鬼和邪恶力量的存在。欧洲启示录是一个充满象征和暗示的文本,对于解读和理解有许多不同的方法和观点。"
}
```
Build your own dataset by following the sample `parquet` files in the `data` directory. The dataset format: a `parquet` file with two columns, a `prompt` text column containing the prompt and a `response` text column containing the expected model output.
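
A minimal sketch of building such a two-column `parquet` file with `pandas` (the row content and output path are placeholders; writing parquet requires `pyarrow` or `fastparquet`):

```python
import pandas as pd

df = pd.DataFrame({
    "prompt":   ["请解释什么是机器学习[EOS]"],       # [EOS] marker per the caution in 3.2
    "response": ["机器学习是人工智能的一个分支……[EOS]"],
})
df.to_parquet("./data/my_sft_dataset.parquet", index=False)
```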
For fine-tuning details, see the `train` method in `model/trainer.py`. When `is_finetune` is set to `True`, fine-tuning is performed; it freezes the embedding and encoder layers by default and only trains the decoder layers. If you need to freeze other parameters, adjust the code yourself.

Run SFT fine-tuning:
```bash
# For the trainer implemented in this project, just add `--is_finetune=True`; `--is_keep_training=True` resumes from any breakpoint.
accelerate launch --multi_gpu --num_processes 2 ./train.py --is_finetune=True

# Or use the huggingface trainer
python sft_train.py
```

## 3.6 RLHF (reinforcement learning from human feedback)

Two common preference methods are introduced here: PPO and DPO. For concrete implementations, please look up papers and blog posts.

1. PPO (Proximal Policy Optimization)
   Step 1: Use the fine-tuning dataset for supervised fine-tuning (SFT, Supervised Finetuning).
   Step 2: Use a preference dataset (each prompt has at least 2 responses, one wanted and one unwanted; multiple responses can be ranked by score, with the most wanted having the highest score) to train a reward model (RM, Reward Model). The `peft` library can be used to quickly build a Lora reward model.
   Step 3: Use the RM to run supervised PPO training on the SFT model so that the model matches the preferences.

2. DPO (Direct Preference Optimization) fine-tuning (**this project uses DPO, which saves GPU memory**)
   Starting from the SFT model, no reward model needs to be trained; fine-tuning can start once positive answers (chosen) and negative answers (rejected) are obtained. The `chosen` texts come from the original dataset [alpaca-gpt4-data-zh](https://huggingface.co/datasets/c-s-ale/alpaca-gpt4-data-zh); the `rejected` texts come from the output of the SFT model after 1 epoch of fine-tuning. Two other datasets, [huozi_rlhf_data_json](https://huggingface.co/datasets/Skepsun/huozi_rlhf_data_json) and [rlhf-reward-single-round-trans_chinese](https://huggingface.co/datasets/beyond/rlhf-reward-single-round-trans_chinese), were merged in, giving 80,000 DPO samples in total.

For the DPO dataset processing pipeline, see `utils/dpo_data_process.py`.

DPO preference optimization dataset example:
```json
{
    "prompt": "为给定的产品创建一个创意标语。,输入:可重复使用的水瓶。",
    "chosen": "\"保护地球,从拥有可重复使用的水瓶开始!\"",
    "rejected": "\"让你的水瓶成为你的生活伴侣,使用可重复使用的水瓶,让你的水瓶成为你的伙伴\""
}
```
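
A heavily hedged sketch of what this stage looks like with `trl`'s `DPOTrainer`; this is not the project's `dpo_train.py`, argument names vary across `trl` versions, and the dataset path and `beta` value are assumptions (the epoch count and `fp16` setting follow section 2.3):

```python
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForSeq2SeqLM.from_pretrained("./model_save", trust_remote_code=True)
ref_model = AutoModelForSeq2SeqLM.from_pretrained("./model_save", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("./model_save")

# The dataset must expose "prompt", "chosen" and "rejected" columns
train_dataset = load_dataset("json", data_files="./data/dpo_data.json", split="train")

trainer = DPOTrainer(
    model,
    ref_model,
    args=TrainingArguments(output_dir="./model_save/dpo", num_train_epochs=2, fp16=True),
    beta=0.1,                 # strength of the preference penalty
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```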
Run preference optimization:
```bash
python dpo_train.py
```

## 3.7 Inference
Make sure the following files exist in the `model_save` directory; they can all be found in the `Hugging Face Hub` repository [ChatLM-Chinese-0.2B](https://huggingface.co/charent/ChatLM-mini-Chinese):
```bash
ChatLM-mini-Chinese
├─model_save
|  ├─config.json
|  ├─configuration_chat_model.py
|  ├─generation_config.json
|  ├─model.safetensors
|  ├─modeling_chat_model.py
|  ├─special_tokens_map.json
|  ├─tokenizer.json
|  └─tokenizer_config.json
```

1. Console run:
```bash
python cli_demo.py
```

2. API call
```bash
python api_demo.py
```

API call example:
```bash
curl --location '127.0.0.1:8812/api/chat' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer Bearer' \
--data '{
    "input_txt": "感冒了要怎么办"
}'
```
![api demo](./img/api_example.png)

## 3.8 Fine-tuning for downstream tasks

Here, extracting triples from text is used as the downstream fine-tuning example; for traditional deep-learning extraction methods for this task, see the repository [pytorch_IE_model](https://github.com/charent/pytorch_IE_model). The goal is to extract all triples in a piece of text: for example, from the sentence `"Sketching Essays" is a book published by Metallurgical Industry in 2006, the author is Zhang Lailiang`, extract the triples `(Sketching Essays, author, Zhang Lailiang)` and `(Sketching Essays, publisher, Metallurgical Industry)`.

The original dataset is the [Baidu triple extraction dataset](https://aistudio.baidu.com/datasetdetail/11384). Example of the processed fine-tuning dataset format:
```json
{
    "prompt": "请抽取出给定句子中的所有三元组。给定句子:《家乡的月亮》是宋雪莱演唱的一首歌曲,所属专辑是《久违的哥们》",
    "response": "[(家乡的月亮,歌手,宋雪莱),(家乡的月亮,所属专辑,久违的哥们)]"
}
```

You can use the `sft_train.py` script directly for fine-tuning. The script [finetune_IE_task.ipynb](./finetune_examples/info_extract/finetune_IE_task.ipynb) contains the detailed decoding process. The training dataset has about `17000` samples, the learning rate is `5e-5`, and training runs for `5` epochs. The dialogue ability on other tasks has not disappeared after fine-tuning.

![Dialogue ability after fine-tuning on the information extraction task](./img/ie_task_chat.png)

Fine-tuning results:
The public `dev` split of the `Baidu triple extraction dataset` is used as the test set for comparison with the traditional method [pytorch_IE_model](https://github.com/charent/pytorch_IE_model).

| Model | F1 score | Precision | Recall |
| :--- | :----: | :---: | :---: |
| ChatLM-Chinese-0.2B fine-tuned | 0.74 | 0.75 | 0.73 |
| ChatLM-Chinese-0.2B without pre-training | 0.51 | 0.53 | 0.49 |
| Traditional deep learning method | 0.80 | 0.79 | 0.80 |

Note: `ChatLM-Chinese-0.2B without pre-training` means training directly from randomly initialized parameters with learning rate `1e-4`; the other settings are the same as for fine-tuning.

## 3.9 C-Eval score
The model itself was not trained on a large dataset and was not fine-tuned on instructions for answering multiple-choice questions, so the C-Eval score is basically at baseline level; treat it as a reference only if needed. The C-Eval evaluation code is at `eval/c_eavl.ipynb`.

| category | correct | question_count | accuracy |
| :--- | :----: | :---: | :---: |
| Humanities | 63 | 257 | 24.51% |
| Other | 89 | 384 | 23.18% |
| STEM | 89 | 430 | 20.70% |
| Social Science | 72 | 275 | 26.18% |

# 4. 🎓Citation
If you find this project helpful, please cite it:
```conf
@misc{Charent2023,
    author={Charent Chen},
    title={A small chinese chat language model with 0.2B parameters base on T5},
    year={2023},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/charent/ChatLM-mini-Chinese}},
}
```

# 5. 🤔Other matters
This project does not assume any risks or responsibilities arising from data security or public-opinion risks caused by the open-source model and code, or from any model being misled, abused, disseminated, or improperly exploited.
README.md CHANGED
@@ -1,12 +1,474 @@
- ---
- title: ChatmlTest
- emoji: 🐨
- colorFrom: red
- colorTo: blue
- sdk: gradio
- sdk_version: 4.26.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
<div align="center">

# A Small 0.2B Chinese Dialogue Model: ChatLM-Chinese-0.2B

中文 | [English](./README.en.md)

</div>


# 1. 👋Introduction
Today's large language models tend to have huge parameter counts, and consumer-grade computers are slow even at simple inference, let alone training a model from scratch. The goal of this project is to train a generative language model from scratch, covering data cleaning, tokenizer training, model pre-training, SFT instruction fine-tuning, RLHF optimization, and so on.

ChatLM-mini-Chinese is a small Chinese dialogue model with only 0.2B parameters (about 210M counting shared weights). It can be pre-trained on a machine with as little as 4GB of GPU memory (`batch_size=1`, `fp16` or `bf16`), and `float16` loading and inference need only 512MB of GPU memory at minimum.


- All sources of the pre-training, SFT instruction fine-tuning, and DPO preference optimization datasets are public.
- Uses the `Huggingface` NLP stack, including `transformers`, `accelerate`, `trl`, `peft`, etc.
- Self-implemented `trainer`, supporting pre-training and SFT fine-tuning on a single GPU or multiple GPUs on one machine; training can be stopped at any position and resumed at any position.
- Pre-training: end-to-end `Text-to-Text` pre-training rather than `mask` prediction pre-training.
  - All data cleaning (normalization, MinHash-based document deduplication, etc.), dataset construction, and dataset loading optimizations are open source;
  - Multi-process word-frequency counting for the tokenizer; supports tokenizer training with `sentencepiece` and `huggingface tokenizers`;
  - Pre-training supports breakpoints at any position and resuming from a breakpoint;
  - Streaming loading of large (GB-scale) datasets with buffered shuffling; neither RAM nor disk is used as a cache, which effectively reduces memory and disk usage. With `batch_size=1, max_len=320`, pre-training works on a machine with as little as 16GB RAM + 4GB GPU memory;
  - Training logging.
- SFT fine-tuning: the SFT dataset and its processing pipeline are open source.
  - The self-implemented `trainer` supports prompt instruction fine-tuning and resuming from any breakpoint;
  - Supports `sequence to sequence` fine-tuning with the `Huggingface trainer`;
  - Supports the traditional low-learning-rate approach of training only the decoder layers.
- RLHF preference optimization: full preference optimization with DPO.
  - Supports preference optimization with `peft lora`;
  - Supports model merging: a `Lora adapter` can be merged back into the original model (see the sketch after this list).
- Supports downstream task fine-tuning: [finetune_examples](./finetune_examples/info_extract/) gives a fine-tuning example for the **triple information extraction task**; the model keeps its dialogue ability after fine-tuning.
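
A minimal sketch of the adapter merge mentioned above, using `peft` (the adapter path is a hypothetical placeholder):

```python
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM

base = AutoModelForSeq2SeqLM.from_pretrained("./model_save", trust_remote_code=True)
lora = PeftModel.from_pretrained(base, "./model_save/lora_adapter")  # placeholder path

merged = lora.merge_and_unload()   # fold the LoRA weights back into the base model
merged.save_pretrained("./model_save/merged")
```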

If you need retrieval-augmented generation (RAG) based on small models, see my other project [Phi2-mini-Chinese](https://github.com/charent/Phi2-mini-Chinese); for the code, see [rag_with_langchain.ipynb](https://github.com/charent/Phi2-mini-Chinese/blob/main/rag_with_langchain.ipynb)

🟢**Latest updates**

<details open>
<summary> <b>2024-01-30</b> </summary>
- The model files have been uploaded to ModelScope and can be downloaded quickly via `snapshot_download`.<br/>
</details>

<details close>
<summary> <b>2024-01-07</b> </summary>
- Added MinHash-based document deduplication to the data cleaning process (in this project it actually deduplicates dataset rows), to prevent the model from regurgitating training data at inference time after seeing the same data repeatedly.<br/>
- Added the `DropDatasetDuplicate` class for deduplicating documents in large datasets.<br/>
</details>

<details close>
<summary> <b>2023-12-29</b> </summary>
- Updated the model code (weights unchanged); the model can now be loaded directly with `AutoModelForSeq2SeqLM.from_pretrained(...)`.<br/>
- Updated the readme.<br/>
</details>

<details close>
<summary> <b>2023-12-18</b> </summary>
- Added code for fine-tuning the `ChatLM-mini-0.2B` model on the downstream triple information extraction task, plus a showcase of the extraction results.<br/>
- Updated the readme.<br/>
</details>

<details close>
<summary> <b>2023-12-14</b> </summary>
- Updated the model weight files after SFT and DPO. <br/>
- Updated the pre-training, SFT, and DPO scripts. <br/>
- Updated the `tokenizer` to `PreTrainedTokenizerFast`. <br/>
- Refactored the `dataset` code to support dynamic maximum length: the maximum length of each batch is determined by the longest text in that batch, saving GPU memory. <br/>
- Added `tokenizer` training details. <br/>
</details>

<details close>
<summary> <b>2023-12-04</b> </summary>
- Updated `generate` parameters and the model showcase.<br/>
- Updated the readme.<br/>
</details>

<details close>
<summary> <b>2023-11-28</b> </summary>
- Updated the DPO training code and model weights.<br/>
</details>

<details close>
<summary> <b>2023-10-19</b> </summary>
- The project is open source, and the model weights are available for download. <br/>
</details>


# 2. 🛠️ChatLM-0.2B-Chinese model training process

## 2.1 Pre-training dataset
All datasets come from **single-round conversation** datasets published on the Internet. After cleaning and formatting, they are saved as parquet files. For the processing pipeline, see `utils/raw_data_process.py`. The main datasets are:

1. Community Q&A json version webtext2019zh, a large-scale high-quality dataset, see [nlp_chinese_corpus](https://github.com/brightmart/nlp_chinese_corpus). 4.1 million in total, 2.6 million remaining after cleaning.
2. baike_qa2019 encyclopedia Q&A, see <https://aistudio.baidu.com/datasetdetail/107726>. 1.4 million in total, 1.3 million remaining after cleaning.
3. Chinese medical Q&A dataset, see [Chinese-medical-dialogue-data](https://github.com/Toyhom/Chinese-medical-dialogue-data). 790,000 in total, 790,000 remaining after cleaning.
4. ~~Financial industry Q&A data, see <https://zhuanlan.zhihu.com/p/609821974>. 770,000 in total, 520,000 remaining after cleaning.~~ **The data quality is too poor; not used.**
5. Zhihu Q&A data, see [Zhihu-KOL](https://huggingface.co/datasets/wangrui6/Zhihu-KOL). 1 million rows in total, 970,000 rows remaining after cleaning.
6. belle open-source instruction training data; introduction: [BELLE](https://github.com/LianjiaTech/BELLE), download: [BelleGroup](https://huggingface.co/BelleGroup). Only the parts of `Belle_open_source_1M`, `train_2M_CN`, and `train_3.5M_CN` with short answers, no complex table structure, and no translation tasks (the vocabulary has no English word list) were selected, totaling 3.7 million rows, with 3.38 million rows remaining after cleaning.
7. Wikipedia entry data: entries are assembled into prompts, and the first `N` words of each article serve as the answer. Uses the `202309` Wikipedia dump; 1.19 million entry prompts and answers remain after cleaning. Wiki download: [zhwiki](https://dumps.wikimedia.org/zhwiki/); to convert the downloaded bz2 file to wiki.txt, see [WikiExtractor](https://github.com/apertium/WikiExtractor).

The datasets total 10.23 million samples. Text-to-Text pre-training set: 9.3 million; evaluation set: 25,000 (because decoding is slow, the evaluation set is kept small). ~~Test set: 900,000.~~
The SFT fine-tuning and DPO optimization datasets are described below.

## 2.2 Model
T5 model (Text-to-Text Transfer Transformer); for details, see the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683).

The model source code comes from huggingface; see [T5ForConditionalGeneration](https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1557).

For the model configuration, see [model_config.json](https://huggingface.co/charent/ChatLM-mini-Chinese/blob/main/config.json). The official `T5-base` uses 12 `encoder layer`s and 12 `decoder layer`s; this project reduces both to 10 layers.

Model parameters: 0.2B. Vocabulary size: 29298, containing only Chinese and a small amount of English.

## 2.3 Training process
Hardware:
```bash
# Pre-training phase:
CPU: 28 vCPU Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz
Memory: 60 GB
GPU: RTX A5000 (24GB) * 2

# SFT and DPO phases:
CPU: Intel(R) i5-13600k @ 5.1GHz
Memory: 32 GB
GPU: NVIDIA GeForce RTX 4060 Ti 16GB * 1
```
1. **Tokenizer training**: existing `tokenizer` training libraries run into OOM problems on large corpora, so the full corpus was merged and the vocabulary built by word frequency with a method similar to `BPE`; the run took half a day.

2. **Text-to-Text pre-training**: a dynamic learning rate from `1e-4` to `5e-3` (see the schedule sketch after this list); pre-training took 8 days. Training loss:

![training loss](img/train_loss.png)

3. **Prompt supervised fine-tuning (SFT)**: uses the `belle` instruction training dataset (both instruction and answer lengths below 512), with a dynamic learning rate from `1e-7` to `5e-5`; fine-tuning took 2 days. Fine-tuning loss:

![finetune loss](img/sft_loss.png)

4. **DPO direct preference optimization (RLHF)**: the dataset [alpaca-gpt4-data-zh](https://huggingface.co/datasets/c-s-ale/alpaca-gpt4-data-zh) provides the `chosen` texts; the SFT model from step `3` runs batch `generate` on the prompts in the dataset to obtain the `rejected` texts, which takes 1 day. Full DPO preference optimization with learning rate `1e-5`, half precision `fp16`, `2` `epoch`s in total, taking 3 hours. DPO loss:

![dpo loss](img/dpo_loss.png)
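
One way to get a ramp-up/decay schedule like the ones described in steps 2 and 3 is PyTorch's `OneCycleLR`; this is a sketch under an assumed step count, not the project's exact scheduler (`model` is assumed to exist):

```python
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# LR climbs toward max_lr during the first 10% of steps, then anneals back down
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=5e-3, total_steps=100_000, pct_start=0.1,
)

for step in range(100_000):
    ...                # forward / backward / optimizer.step()
    scheduler.step()   # advance the schedule once per optimizer step
```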
138
+
139
+ ## 2.4 对话效果展示
140
+ ### 2.4.1 stream chat
141
+ 默认使用`huggingface transformers`的 `TextIteratorStreamer`实现流式对话,只支持`greedy search`,如果需要`beam sample`等其他生成方式,请将`cli_demo.py`的`stream_chat`参数修改为`False`。
142
+ ![](./img/stream_chat.gif)
143
+
144
+ ### 2.4.2 对话展示
145
+ ![](./img/show1.png)
146
+
147
+ 存在问题:预训练数据集只有900多万,模型参数也仅0.2B,不能涵盖所有方面,会有答非所问、废话生成器的情况。
148
+
149
+ # 三、📑使用说明
150
+
151
+ ## 3.1 快速开始:
152
+ 如果无法连接huggingface,请使用`modelscope.snapshot_download`从modelscope下载模型文件。
153
+ ```python
154
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
155
+ import torch
156
+
157
+ model_id = 'charent/ChatLM-mini-Chinese'
158
+
159
+ # 如果无法连接huggingface,打开以下两行代码的注释,将从modelscope下载模型文件,模型文件保存到'./model_save'目录
160
+ # from modelscope import snapshot_download
161
+ # model_id = snapshot_download(model_id, cache_dir='./model_save')
162
+
163
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
164
+
165
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
166
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, trust_remote_code=True).to(device)
167
+
168
+ txt = '如何评价Apple这家公司?'
169
+
170
+ encode_ids = tokenizer([txt])
171
+ input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])
172
+
173
+ outs = model.my_generate(
174
+ input_ids=input_ids.to(device),
175
+ attention_mask=attention_mask.to(device),
176
+ max_seq_len=256,
177
+ search_type='beam',
178
+ )
179
+
180
+ outs_txt = tokenizer.batch_decode(outs.cpu().numpy(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
181
+ print(outs_txt[0])
182
+ ```
183
+ ```txt
184
+ Apple是一家专注于设计和用户体验的公司,其产品在设计上注重简约、流畅和功能性,而在用户体验方面则注重用户的反馈和使用体验。作为一家领先的科技公司,苹果公司一直致力于为用户提供最优质的产品和服务,不断推陈出新,不断创新和改进,以满足不断变化的市场需求。
185
+ 在iPhone、iPad和Mac等产品上,苹果公司一直保持着创新的态度,不断推出新的功能和设计,为用户提供更好的使用体验。在iPad上推出的iPad Pro和iPod touch等产品,也一直保持着优秀的用户体验。
186
+ 此外,苹果公司还致力于开发和销售软件和服务,例如iTunes、iCloud和App Store等,这些产品在市场上也获得了广泛的认可和好评。
187
+ 总的来说,苹果公司在设计、用户体验和产品创新方面都做得非常出色,为用户带来了许多便利和惊喜。
188
+
189
+ ```
190
+
191
+ ## 3.2 从克隆仓库代码开始
192
+
193
+ > [!CAUTION]
194
+ > 本项目模型为`TextToText`模型,在预训练、SFT、RLFH阶段的`prompt`、`response`等字段,请务必加上`[EOS]`序列结束标记。
195
+
196
+
197
+ ### 3.2.1 克隆项目:
198
+ ```bash
199
+ git clone --depth 1 https://github.com/charent/ChatLM-mini-Chinese.git
200
+
201
+ cd ChatLM-mini-Chinese
202
+ ```
203
+ ### 3.2.2 安装依赖
204
+
205
+ 本项目推荐使用`python 3.10`,过老的python版本可能不兼容所依赖的第三方库。
206
+
207
+ pip安装:
208
+ ```bash
209
+ pip install -r ./requirements.txt
210
+ ```
211
+
212
+ 如果pip安装了CPU版本的pytorch,可以通过下面的命令安装CUDA版本的pytorch:
213
+ ```bash
214
+ # pip 安装torch + cu118
215
+ pip3 install torch --index-url https://download.pytorch.org/whl/cu118
216
+ ```
217
+
218
+ conda安装:
219
+ ```bash
220
+ conda install --yes --file ./requirements.txt
221
+ ```
222
+
223
+ ### 3.2.3 下载预训练模型及模型配置文件
224
+
225
+ 用`git`命令从`Hugging Face Hub`下载模型权重及配置文件,需要先安装[Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage),然后运行:
226
+
227
+ ```bash
228
+ # 使用git命令下载huggingface模型,先安装[Git LFS],否则下载的模型文件不可用
229
+ git clone --depth 1 https://huggingface.co/charent/ChatLM-mini-Chinese
230
+
231
+ # 如果无法连接huggingface,请从modelscope下载
232
+ git clone --depth 1 https://www.modelscope.cn/charent/ChatLM-mini-Chinese.git
233
+
234
+ mv ChatLM-mini-Chinese model_save
235
+ ```
236
+
237
+ 也可以直接从`Hugging Face Hub`仓库[ChatLM-Chinese-0.2B](https://huggingface.co/charent/ChatLM-mini-Chinese)手工下载,将下载的文件移动到`model_save`目录下即可。
238
+
239
+ ## 3.3 Tokenizer训练
240
+
241
+ 1. 准备txt语料
242
+
243
+ 语料要求尽可能全,建议添加多个语料,如百科、代码、论文、博客、对话等。
244
+
245
+ 本项目以wiki中文百科为主。获取中文wiki语料方法:中文Wiki下载地址:[zhwiki](https://dumps.wikimedia.org/zhwiki/),下载`zhwiki-[存档日期]-pages-articles-multistream.xml.bz2`文件,大概2.7GB, 将下载的bz2文件转换为wiki.txt参考:[WikiExtractor](https://github.com/apertium/WikiExtractor),再利用python的`OpenCC`库转换为简体中文,最后将得到的`wiki.simple.txt`放到项目根目录的`data`目录下即可。多个语料请自行合并为一个`txt`文件。
246
+
247
+ 由于训练tokenizer非常耗内存,如果你的语料非常大(合并后的`txt`文件超过2G),建议对语料按照类别、比例进行采样,以减少训练时间和内存消耗。训练1.7GB的`txt`文件需要消耗48GB左右的内存(预估的,我只有32GB,频繁触发swap,电脑卡了好久T_T),13600k cpu耗时1小时左右。
248
+
249
+ 2. 训练tokenizer
250
+
251
+ `char level`和`byte level`的区别如下(具体使用上的区别请自行检索资料)。默认训练`char level`的tokenizer,如果需要`byte level`,在`train_tokenizer.py`中设置`token_type='byte'`即可。
252
+
253
+ ```python
254
+ # 原始文本
255
+ txt = '这是一段中英混输的句子, (chinese and English, here are words.)'
256
+
257
+ tokens = charlevel_tokenizer.tokenize(txt)
258
+ print(tokens)
259
+ # char level tokens输出
260
+ # ['▁这是', '一段', '中英', '混', '输', '的', '句子', '▁,', '▁(', '▁ch', 'inese', '▁and', '▁Eng', 'lish', '▁,', '▁h', 'ere', '▁', 'are', '▁w', 'ord', 's', '▁.', '▁)']
261
+
262
+ tokens = bytelevel_tokenizer.tokenize(txt)
263
+ print(tokens)
264
+ # byte level tokens输出
265
+ # ['Ġè¿Ļæĺ¯', 'ä¸Ģ段', 'ä¸Ńèĭ±', 'æ··', 'è¾ĵ', 'çļĦ', 'åı¥åŃIJ', 'Ġ,', 'Ġ(', 'Ġch', 'inese', 'Ġand', 'ĠEng', 'lish', 'Ġ,', 'Ġh', 'ere', 'Ġare', 'Ġw', 'ord', 's', 'Ġ.', 'Ġ)']
266
+ ```
267
+ 开始训练:
268
+ ```python
269
+ # 确保你的训练语料`txt`文件已经data目录下
270
+ python train_tokenizer.py
271
+ ```
272
+
273
+ ## 3.4 Text-to-Text 预���练
274
+
275
+ 1. 预训练数据集示例
276
+ ```json
277
+ {
278
+ "prompt": "对于花园街,你有什么了解或看法吗?",
279
+ "response": "花园街(是香港油尖旺区的一条富有特色的街道,位于九龙旺角东部,北至界限街,南至登打士街,与通菜街及洗衣街等街道平行。现时这条街道是香港著名的购物区之一。位于亚皆老街以南的一段花园街,也就是\"波鞋街\"整条街约150米长,有50多间售卖运动鞋和运动用品的店舖。旺角道至太子道西一段则为排档区,售卖成衣、蔬菜和水果等。花园街一共分成三段。明清时代,花园街是芒角村栽种花卉的地方。此外,根据历史专家郑宝鸿的考证:花园街曾是1910年代东方殷琴拿烟厂的花园。纵火案。自2005年起,花园街一带最少发生5宗纵火案,当中4宗涉及排档起火。2010年。2010年12月6日,花园街222号一个卖鞋的排档于凌晨5时许首先起火,浓烟涌往旁边住宅大厦,消防接报4"
280
+ }
281
+ ```
282
+
283
+ 2. jupyter-lab 或者 jupyter notebook:
284
+
285
+ 见文件`train.ipynb`,推荐使用jupyter-lab,避免考虑与服务器断开后终端进程被杀的情况。
286
+
287
+ 3. 控制台:
288
+
289
+ 控制台训练需要考虑连接断开后进程被杀的,推荐使用进程守护工具`Supervisor`或者`screen`建立连接会话。
290
+
291
+ 首先要配置`accelerate`,执行以下命令, 根据提示选择即可,参考`accelerate.yaml`,*注意:DeepSpeed在Windows安装比较麻烦*。
292
+ ```bash
293
+ accelerate config
294
+ ```
295
+
296
+ 开始训练,如果要使用工程提供的配置请在下面的命令`accelerate launch`后加上参数`--config_file ./accelerate.yaml`,*该配置按照单机2xGPU配置。*
297
+
298
+ *预训练有两个脚本,本项目实现的trainer对应`train.py`,huggingface实现的trainer对应`pre_train.py`,用哪个都可以,效果一致。本项目实现的trainer训练信息展示更美观、更容易修改训练细节(如损失函数,日志记录等),均支持断点继续训练,本项目实现的trainer支持在任意位置断点后继续训练,按`ctrl+c`退出脚本时会保存断点信息。*
+
+ Single machine, single GPU:
+ ```bash
+ # trainer implemented in this project
+ accelerate launch ./train.py train
+
+ # or use the huggingface trainer
+ python pre_train.py
+ ```
+
+ Single machine, multiple GPUs:
+ `2` is the number of GPUs; change it to match your actual setup.
+ ```bash
+ # trainer implemented in this project
+ accelerate launch --multi_gpu --num_processes 2 ./train.py train
+
+ # or use the huggingface trainer
+ accelerate launch --multi_gpu --num_processes 2 pre_train.py
+ ```
+
+ Resume training from a break point:
+ ```bash
+ # trainer implemented in this project
+ accelerate launch --multi_gpu --num_processes 2 ./train.py train --is_keep_training=True
+
+ # or use the huggingface trainer
+ # you also need to add `resume_from_checkpoint=True` to the `train` call in `pre_train.py`
+ accelerate launch --multi_gpu --num_processes 2 pre_train.py
+ ```
+
+ ## 3.5 SFT fine-tuning
+ The SFT datasets all come from the contribution of [BELLE](https://github.com/LianjiaTech/BELLE), thank you. The SFT datasets are [generated_chat_0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M), [train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) and [train_2M_CN](https://huggingface.co/datasets/BelleGroup/train_2M_CN), with roughly 1.37 million rows left after cleaning.
+ Example from the SFT instruction fine-tuning dataset:
+ ```json
+ {
+     "prompt": "解释什么是欧洲启示录",
+     "response": "欧洲启示录(The Book of Revelation)是新约圣经的最后一卷书,也被称为《启示录》、《默示录》或《约翰默示录》。这本书从宗教的角度描述了世界末日的来临,以及上帝对世界的审判和拯救。 书中的主题包括来临的基督的荣耀,上帝对人性的惩罚和拯救,以及魔鬼和邪恶力量的存在。欧洲启示录是一个充满象征和暗示的文本,对于解读和理解有许多不同的方法和观点。"
+ }
+ ```
+
+ To make your own dataset, follow the example `parquet` files in the `data` directory (a build sketch follows below). Dataset format: the `parquet` file has two text columns, `prompt` for the prompt and `response` for the expected model output.
+ For fine-tuning details see the `train` method in `model/trainer.py`; when `is_finetune` is set to `True`, fine-tuning is performed. Fine-tuning freezes the embedding and encoder layers by default and only trains the decoder layers. If you need to freeze other parameters, adjust the code yourself.
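+
+ A minimal sketch of producing such a two-column `parquet` file with `pandas` (requires `pyarrow` or `fastparquet` installed; the file name is just an example):
+ ```python
+ import pandas as pd
+
+ samples = [
+     {'prompt': '请介绍一下大熊猫。', 'response': '大熊猫是一种生活在中国的哺乳动物……'},
+     # ... more samples
+ ]
+
+ # two text columns: prompt and response
+ df = pd.DataFrame(samples, columns=['prompt', 'response'])
+ df.to_parquet('./data/my_sft_dataset.parquet', index=False)
+ ```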
+
+ Run SFT fine-tuning:
+ ```bash
+ # for the trainer implemented in this project, just add `--is_finetune=True`; `--is_keep_training=True` resumes training from any break point
+ accelerate launch --multi_gpu --num_processes 2 ./train.py --is_finetune=True
+
+ # or use the huggingface trainer; for multiple GPUs use: accelerate launch --multi_gpu --num_processes <gpu count> sft_train.py
+ python sft_train.py
+ ```
+
+ ## 3.6 RLHF (reinforcement learning from human feedback)
+
+ Two common preference-optimization methods are introduced here: PPO and DPO; please search papers and blogs for the concrete implementations.
+
+ 1. PPO (Proximal Policy Optimization)
+ Step 1: do supervised fine-tuning with the fine-tuning dataset (SFT, Supervised Finetuning).
+ Step 2: train a reward model (RM, Reward Model) with a preference dataset (each prompt carries at least 2 responses, one preferred and one rejected; multiple responses can be ranked by score, with the most preferred scored highest). The `peft` library makes it quick to build a LoRA reward model.
+ Step 3: use the RM to run supervised PPO training on the SFT model so that the model satisfies the preferences.
+
+ 2. Fine-tuning with DPO (Direct Preference Optimization) (**this project uses the DPO method, which saves GPU memory**)
+ Starting from the SFT model, no reward model needs to be trained; fine-tuning can begin once positive (chosen) and negative (rejected) responses are obtained. The `chosen` texts come from the original dataset [alpaca-gpt4-data-zh](https://huggingface.co/datasets/c-s-ale/alpaca-gpt4-data-zh), and the `rejected` texts come from the model output after 1 epoch of SFT fine-tuning; two further datasets, [huozi_rlhf_data_json](https://huggingface.co/datasets/Skepsun/huozi_rlhf_data_json) and [rlhf-reward-single-round-trans_chinese](https://huggingface.co/datasets/beyond/rlhf-reward-single-round-trans_chinese), are merged in, giving 80k DPO samples in total.
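+
+ As an illustration, a DPO sample can be assembled like this (a sketch only; `generate_with_sft_model` is a hypothetical placeholder for generation with the 1-epoch SFT checkpoint):
+ ```python
+ def generate_with_sft_model(prompt: str) -> str:
+     # hypothetical: run generation with the weaker 1-epoch SFT checkpoint
+     raise NotImplementedError
+
+ def build_dpo_sample(prompt: str, reference_answer: str) -> dict:
+     return {
+         'prompt': prompt,                             # the instruction
+         'chosen': reference_answer,                   # reference answer from the source dataset
+         'rejected': generate_with_sft_model(prompt),  # weaker output from the SFT model
+     }
+ ```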
+
+ See `utils/dpo_data_process.py` for the DPO dataset processing.
+
+ DPO preference-optimization dataset example:
+ ```json
+ {
+     "prompt": "为给定的产品创建一个创意标语。,输入:可重复使用的水瓶。",
+     "chosen": "\"保护地球,从拥有可重复使用的水瓶开始!\"",
+     "rejected": "\"让你的水瓶成为你的生活伴侣,使用可重复使用的水瓶,让你的水瓶成为你的伙伴\""
+ }
+ ```
+
+ Run preference optimization:
+ ```bash
+ # for multiple GPUs use: accelerate launch --multi_gpu --num_processes <gpu count> dpo_train.py
+ python dpo_train.py
+ ```
+
+ ## 3.7 Inference
+ Make sure the following files are present in the `model_save` directory; all of them can be found in the `Hugging Face Hub` repository [ChatLM-Chinese-0.2B](https://huggingface.co/charent/ChatLM-mini-Chinese):
+ ```bash
+ ChatLM-mini-Chinese
+ ├─model_save
+ |  ├─config.json
+ |  ├─configuration_chat_model.py
+ |  ├─generation_config.json
+ |  ├─model.safetensors
+ |  ├─modeling_chat_model.py
+ |  ├─special_tokens_map.json
+ |  ├─tokenizer.json
+ |  └─tokenizer_config.json
+ ```
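+
+ With these files in place the model can also be loaded programmatically (a sketch; `trust_remote_code=True` is assumed to be needed because the repository ships its own model code):
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ model_dir = './model_save'
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_dir, trust_remote_code=True)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device).eval()
+ ```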
+
+ 1. Run from the console:
+ ```bash
+ python cli_demo.py
+ ```
+
+ 2. API call
+ ```bash
+ python api_demo.py
+ ```
+
+ API call example:
+ ```bash
+ curl --location '127.0.0.1:8812/api/chat' \
+ --header 'Content-Type: application/json' \
+ --header 'Authorization: Bearer Bearer' \
+ --data '{
+     "input_txt": "感冒了要怎么办"
+ }'
+ ```
+ ![api demo](./img/api_example.png)
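+
+ The same call from Python (a sketch using the `requests` library, assuming the server above is running locally):
+ ```python
+ import requests
+
+ resp = requests.post(
+     'http://127.0.0.1:8812/api/chat',
+     headers={'Authorization': 'Bearer Bearer'},  # only checked when an api_key is configured
+     json={'input_txt': '感冒了要怎么办'},
+ )
+ print(resp.json()['response'])
+ ```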
+
+ ## 3.8 Downstream task fine-tuning
+
+ Here fine-tuning on triple extraction from text is used as the downstream example. For the traditional deep-learning extraction approach to this task see the repository [pytorch_IE_model](https://github.com/charent/pytorch_IE_model). The goal is to extract all triples in a piece of text, e.g. from the sentence `《写生随笔》是冶金工业2006年出版的图书,作者是张来亮` extract the triples `(写生随笔,作者,张来亮)` and `(写生随笔,出版社,冶金工业)`.
+
+ The original dataset is the [Baidu triple extraction dataset](https://aistudio.baidu.com/datasetdetail/11384). Example of the processed fine-tuning dataset format:
+ ```json
+ {
+     "prompt": "请抽取出给定句子中的所有三元组。给定句子:《家乡的月亮》是宋雪莱演唱的一首歌曲,所属专辑是《久违的哥们》",
+     "response": "[(家乡的月亮,歌手,宋雪莱),(家乡的月亮,所属专辑,久违的哥们)]"
+ }
+ ```
+
+ You can fine-tune directly with the `sft_train.py` script; the notebook [finetune_IE_task.ipynb](./finetune_examples/info_extract/finetune_IE_task.ipynb) contains the detailed decoding process. The training dataset has about `17000` samples, the learning rate is `5e-5`, and training runs for `5` epochs. The conversational ability on other tasks does not disappear after fine-tuning.
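+
+ A sketch of parsing the model's response back into triples (the regex assumes the bracketed output format shown above, with fullwidth commas inside each triple):
+ ```python
+ import re
+
+ response = "[(家乡的月亮,歌手,宋雪莱),(家乡的月亮,所属专辑,久违的哥们)]"
+
+ triples = []
+ for group in re.findall(r'\(([^()]*)\)', response):
+     parts = group.split(',')   # fields are separated by fullwidth commas
+     if len(parts) == 3:        # keep only well-formed triples
+         triples.append(tuple(parts))
+
+ print(triples)
+ # [('家乡的月亮', '歌手', '宋雪莱'), ('家乡的月亮', '所属专辑', '久违的哥们')]
+ ```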
+
+ ![conversational ability after fine-tuning on the information extraction task](./img/ie_task_chat.png)
+
+ Fine-tuning results:
+ The public `dev` split of the `Baidu triple extraction dataset` is used as the test set, compared against the traditional method [pytorch_IE_model](https://github.com/charent/pytorch_IE_model).
+
+ | Model | F1 score | Precision P | Recall R |
+ | :--- | :----: | :---: | :---: |
+ | ChatLM-Chinese-0.2B fine-tuned | 0.74 | 0.75 | 0.73 |
+ | ChatLM-Chinese-0.2B without pre-training | 0.51 | 0.53 | 0.49 |
+ | traditional deep-learning method | 0.80 | 0.79 | 0.801 |
+
+ Note: `ChatLM-Chinese-0.2B without pre-training` means starting training directly from randomly initialized parameters, with learning rate `1e-4`; the other settings match the fine-tuning run.
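+
+ For reference, a sketch of how P/R/F1 can be computed over predicted and gold triple sets (assuming exact-match comparison):
+ ```python
+ def prf1(pred: set, gold: set) -> tuple:
+     tp = len(pred & gold)                       # correctly predicted triples
+     p = tp / len(pred) if pred else 0.0         # precision
+     r = tp / len(gold) if gold else 0.0         # recall
+     f1 = 2 * p * r / (p + r) if p + r else 0.0
+     return p, r, f1
+
+ print(prf1({('a', '作者', 'b'), ('c', '作者', 'd')}, {('a', '作者', 'b')}))
+ # (0.5, 1.0, 0.6666666666666666)
+ ```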
+
+ ## 3.9 C-Eval score
+ The model was neither trained on a very large dataset nor instruction-tuned for answering multiple-choice questions, so the C-Eval score is basically at baseline level; treat it as a reference if needed. The C-Eval evaluation code is in `eval/c_eavl.ipynb`.
+
+ | category | correct | question_count | accuracy |
+ | :--- | :----: | :---: | :---: |
+ | Humanities | 63 | 257 | 24.51% |
+ | Other | 89 | 384 | 23.18% |
+ | STEM | 89 | 430 | 20.70% |
+ | Social Science | 72 | 275 | 26.18% |
+
+ # 4. 🎓Citation
+ If you find this project helpful, please consider citing it.
+ ```bibtex
+ @misc{Charent2023,
+     author={Charent Chen},
+     title={A small chinese chat language model with 0.2B parameters base on T5},
+     year={2023},
+     publisher = {GitHub},
+     journal = {GitHub repository},
+     howpublished = {\url{https://github.com/charent/ChatLM-mini-Chinese}},
+ }
+ ```
+
+ # 5. 🤔Other matters
+ This project does not assume the risks and responsibilities of data-security or public-opinion issues caused by the open-source model and code, nor any risks and responsibilities arising from the model being misled, abused, spread, or improperly exploited.
+
+ <!-- # Tip
+ ```bash
+ # export the packages the project depends on:
+ pipreqs --encoding "utf-8" --force
+ ``` -->
+
accelerate.yaml ADDED
@@ -0,0 +1,25 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ deepspeed_config:
4
+ gradient_accumulation_steps: 8
5
+ gradient_clipping: 1.0
6
+ offload_optimizer_device: cpu
7
+ offload_param_device: cpu
8
+ zero3_init_flag: false
9
+ zero3_save_16bit_model: false
10
+ zero_stage: 2
11
+ distributed_type: DEEPSPEED
12
+ downcast_bf16: 'no'
13
+ dynamo_config:
14
+ dynamo_backend: EAGER
15
+ machine_rank: 0
16
+ main_training_function: main
17
+ mixed_precision: bf16
18
+ num_machines: 1
19
+ num_processes: 2
20
+ rdzv_backend: static
21
+ same_network: true
22
+ tpu_env: []
23
+ tpu_use_cluster: false
24
+ tpu_use_sudo: false
25
+ use_cpu: false
api_demo.py ADDED
@@ -0,0 +1,104 @@
1
+ from dataclasses import dataclass
2
+ from typing import Union
3
+
4
+ import uvicorn
5
+ from fastapi import FastAPI, Depends, status
6
+ from fastapi.security import OAuth2PasswordBearer
7
+ from fastapi.exceptions import HTTPException
8
+ from pydantic import BaseModel
9
+
10
+ from model.infer import ChatBot
11
+ from config import InferConfig
12
+
13
+ CONFIG = InferConfig()
14
+ chat_bot = ChatBot(infer_config=CONFIG)
15
+
16
+ #==============================================================
17
+ # API configuration
18
+
19
+ # API root path
20
+ ROOT = '/api'
21
+
22
+ # api key
23
+ USE_AUTH = False if len(CONFIG.api_key) == 0 else True
24
+ SECRET_KEY = CONFIG.api_key
25
+
26
+ app = FastAPI()
27
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
28
+
29
+ #==============================================================
30
+
31
+ """
32
+ POST request URL: http://127.0.0.1:8812/api/chat
+ an Authorization header must be added; the body is JSON, for example:
34
+ {
35
+ "input_txt": "感冒了要怎么办"
36
+ }
37
+ """
38
+
39
+ async def api_key_auth(token: str = Depends(oauth2_scheme)) -> Union[None, bool]:
40
+ """
41
+ verify that the key in the POST request matches the server's key
+ the request header must carry: Authorization: Bearer SECRET_KEY
43
+ """
44
+ if not USE_AUTH:
45
+ return None # return None if not auth
46
+
47
+ if token == SECRET_KEY:
48
+ return None # return None if auth success
49
+
50
+ # authentication failed
51
+ raise HTTPException(
52
+ status_code=status.HTTP_401_UNAUTHORIZED,
53
+ detail="api认证未通过,请检查认证方式和token!",
54
+ headers={"WWW-Authenticate": "Bearer"},
55
+ )
56
+
57
+ # POST request JSON body
58
+ class ChatInput(BaseModel):
59
+ input_txt: str
60
+
61
+
62
+ @app.post(ROOT + "/chat")
63
+ async def chat(post_data: ChatInput, authority: str = Depends(api_key_auth)) -> dict:
64
+ """
65
+ POST input: {'input_txt': 'the input text'}
+ response: {'response': 'the chatbot reply'}
67
+ """
68
+ input_txt = post_data.input_txt
69
+ if len(input_txt) == 0:
70
+ raise HTTPException(
71
+ status_code=status.HTTP_406_NOT_ACCEPTABLE,
72
+ detail="input_txt length = 0 is not allow!",
73
+ headers={"WWW-Authenticate": "Bearer"},
74
+ )
75
+
76
+ outs = chat_bot.chat(input_txt)
77
+
78
+ if len(outs) == 0:
79
+ outs = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
80
+
81
+ return {'response': outs}
82
+
83
+ if __name__ == '__main__':
84
+
85
+ # when the reload parameter is set (reload=True), the multi-process setting has no effect
86
+ # workers = max(multiprocessing.cpu_count() * CONFIG.getint('uvicorn','process_worker'), 1)
87
+ workers = max(CONFIG.workers, 1)
88
+ print('number of worker processes: {}'.format(workers))
89
+
90
+ uvicorn.run(
91
+ 'api_demo:app',
92
+ host=CONFIG.host,
93
+ port=CONFIG.port,
94
+ reload=CONFIG.reload,
95
+ workers=workers,
96
+ log_level='info'
97
+ )
98
+
99
+
100
+ # Running as a service:
+ # command line: uvicorn api_demo:app --host 0.0.0.0 --port 8094 --workers 8
+ # api_demo: the api_demo.py file
+ # app: the `app = FastAPI()` object created in api_demo.py
+ # --reload: restart the server on code changes; use it only during development, and note that the multi-process setting is ineffective with it
app.py ADDED
@@ -0,0 +1,37 @@
1
+
2
+ import gradio as gr
3
+ import platform
4
+ import os
5
+ import time
6
+ from threading import Thread
7
+
8
+ from rich.text import Text
9
+ from rich.live import Live
10
+
11
+ from model.infer import ChatBot
12
+ from config import InferConfig
13
+
14
+ infer_config = InferConfig()
15
+ chat_bot = ChatBot(infer_config=infer_config)
16
+ # streamer = chat_bot.chat("你好")
17
+ # print(streamer)
18
+ # streamer = chat_bot.stream_chat("你好")
19
+ # welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
20
+ # def build_prompt(history: list[list[str]]) -> str:
21
+ # prompt = welcome_txt
22
+ # for query, response in history:
23
+ # prompt += '\n\033[0;33;40m用户:\033[0m{}'.format(query)
24
+ # prompt += '\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response)
25
+ # return prompt
26
+ # print(build_prompt(streamer))
27
+
28
+ def greet(name):
+ # reply to the user's input instead of a hard-coded prompt
+ response = chat_bot.chat(name)
+ return response
32
+
33
+
34
+
35
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
36
+
37
+ iface.launch()
cli_demo.py ADDED
@@ -0,0 +1,105 @@
1
+ import platform
2
+ import os
3
+ import time
4
+ from threading import Thread
5
+
6
+ from rich.text import Text
7
+ from rich.live import Live
8
+
9
+ from model.infer import ChatBot
10
+ from config import InferConfig
11
+
12
+ infer_config = InferConfig()
13
+ chat_bot = ChatBot(infer_config=infer_config)
14
+
15
+ clear_cmd = 'cls' if platform.system().lower() == 'windows' else 'clear'
16
+
17
+ welcome_txt = '欢迎使用ChatBot,输入`exit`退出,输入`cls`清屏。\n'
18
+ print(welcome_txt)
19
+
20
+ def build_prompt(history: list[list[str]]) -> str:
21
+ prompt = welcome_txt
22
+ for query, response in history:
23
+ prompt += '\n\033[0;33;40m用户:\033[0m{}'.format(query)
24
+ prompt += '\n\033[0;32;40mChatBot:\033[0m\n{}\n'.format(response)
25
+ return prompt
26
+
27
+ STOP_CIRCLE: bool=False
28
+ def circle_print(total_time: int=60) -> None:
+ '''print a busy spinner while waiting for a non-stream chat reply
+ '''
+ global STOP_CIRCLE
32
+ list_circle = ["\\", "|", "/", "—"]
33
+ for i in range(total_time * 4):
34
+ time.sleep(0.25)
35
+ print("\r{}".format(list_circle[i % 4]), end="", flush=True)
36
+
37
+ if STOP_CIRCLE: break
38
+
39
+ print("\r", end='', flush=True)
40
+
41
+
42
+ def chat(stream: bool=True) -> None:
43
+ global STOP_CIRCLE
44
+ history = []
45
+ turn_count = 0
46
+
47
+ while True:
48
+ print('\r\033[0;33;40m用户:\033[0m', end='', flush=True)
49
+ input_txt = input()
50
+
51
+ if len(input_txt) == 0:
52
+ print('请输入问题')
53
+ continue
54
+
55
+ # exit
56
+ if input_txt.lower() == 'exit':
57
+ break
58
+
59
+ # clear the screen
60
+ if input_txt.lower() == 'cls':
61
+ history = []
62
+ turn_count = 0
63
+ os.system(clear_cmd)
64
+ print(welcome_txt)
65
+ continue
66
+
67
+ if not stream:
68
+ STOP_CIRCLE = False
69
+ thread = Thread(target=circle_print)
70
+ thread.start()
71
+
72
+ outs = chat_bot.chat(input_txt)
73
+
74
+ STOP_CIRCLE = True
75
+ thread.join()
76
+
77
+ print("\r\033[0;32;40mChatBot:\033[0m\n{}\n\n".format(outs), end='')
78
+
79
+ continue
80
+
81
+ history.append([input_txt, ''])
82
+ stream_txt = []
83
+ streamer = chat_bot.stream_chat(input_txt)
84
+ rich_text = Text()
85
+
86
+ print("\r\033[0;32;40mChatBot:\033[0m\n", end='')
87
+
88
+ with Live(rich_text, refresh_per_second=15) as live:
89
+ for i, word in enumerate(streamer):
90
+ rich_text.append(word)
91
+ stream_txt.append(word)
92
+
93
+ stream_txt = ''.join(stream_txt)
94
+
95
+ if len(stream_txt) == 0:
96
+ stream_txt = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
97
+
98
+ history[turn_count][1] = stream_txt
99
+
100
+ os.system(clear_cmd)
101
+ print(build_prompt(history), flush=True)
102
+ turn_count += 1
103
+
104
+ if __name__ == '__main__':
105
+ chat(stream=True)
config.py ADDED
@@ -0,0 +1,139 @@
1
+ from dataclasses import dataclass
2
+ from os.path import dirname, abspath
3
+
4
+ # replace '\' on windows to '/'
5
+ PROJECT_ROOT: str = '/'.join(abspath(dirname(__file__)).split('\\')) if '\\' in abspath(dirname(__file__)) else abspath(dirname(__file__))
6
+
7
+ # ===================================================================================
8
+ # inference configuration
9
+ @dataclass
10
+ class InferConfig:
11
+ max_seq_len: int = 320 # maximum length of a reply
12
+ mixed_precision: str = "bf16" # mixed precision: 'no', 'fp16', 'bf16' or 'fp8'
13
+
14
+ # full DPO model files; the tokenizer files and model weights live in the same folder
15
+ model_dir: str = PROJECT_ROOT + '/model_save/'
16
+
17
+ # model file after merging the lora DPO adapter
18
+ # model_file: str = PROJECT_ROOT + '/model_save/chat_small_t5.best.dpo.lora_merged.bin'
19
+
20
+ # this config is for the api demo:
21
+ api_key: str = ""
22
+ host: str = '127.0.0.1'
23
+ port: int = 8812
24
+ reload: bool = True
25
+ workers: int = 1
26
+ log_level: str = 'info'
27
+
28
+
29
+ #===================================================================================
30
+ # DPO training configuration
31
+ @dataclass
32
+ class DpoConfig:
33
+ max_seq_len: int = 512 + 8 # 8 for eos token
34
+ sft_model_file: str = PROJECT_ROOT + '/model_save/'
35
+
36
+ tokenizer_dir: str = PROJECT_ROOT + '/model_save/' # the tokenizer usually sits in the same folder as the model weights
37
+
38
+ dpo_train_file: str = PROJECT_ROOT + '/data/my_dpo_data.json'
39
+ dpo_eval_file: str = PROJECT_ROOT + '/data/my_dpo_eval.json'
40
+
41
+ adapter_file: str = PROJECT_ROOT + '/data/dpo/adapter_model.safetensors'
42
+ log_dir: str = PROJECT_ROOT + '/logs/'
43
+
44
+ per_device_train_batch_size: int = 4
45
+ num_train_epochs: int = 4
46
+ gradient_accumulation_steps: int = 8
47
+ learning_rate: float = 1e-5
48
+ logging_first_step: bool = True
49
+ logging_steps: int = 20
50
+ save_steps: int = 2000
51
+ output_dir: str = PROJECT_ROOT + '/model_save/dpo'
52
+ warmup_steps: int = 1000
53
+ fp16: bool = True
54
+ seed: int = 23333
55
+ beta: float = 0.1
56
+
57
+
58
+
59
+ # SFT configuration
60
+ @dataclass
61
+ class SFTconfig:
62
+ max_seq_len: int = 384 + 8 # 8 for eos token
63
+
64
+ finetune_from_ckp_file = PROJECT_ROOT + '/model_save/'
65
+
66
+ tokenizer_dir: str = PROJECT_ROOT + '/model_save/' # the tokenizer usually sits in the same folder as the model weights
67
+ sft_train_file: str = PROJECT_ROOT + '/data/sft_train.json'
68
+
69
+ batch_size: int = 12
70
+ num_train_epochs: int = 4
71
+ save_steps: int = 5000
72
+ gradient_accumulation_steps: int = 4
73
+ learning_rate: float = 1e-5
74
+ logging_first_step: bool = True
75
+ logging_steps: int = 100
76
+ output_dir: str = PROJECT_ROOT + '/model_save/sft'
77
+ warmup_steps: int = 100
78
+ fp16: bool = True
79
+ seed: int = 23333
80
+
81
+
82
+ # ===================================================================================
83
+ # pre-training configuration
84
+ @dataclass
85
+ class TrainConfig:
86
+ epochs: int = 8
87
+ batch_size_per_gpu: int = 16
88
+
89
+ learn_rate: float = 0.0001 # the maximum lr is div_factor * learn_rate
90
+ div_factor: int = 50
91
+
92
+ mixed_precision: str = "bf16" # mixed precision: 'no', 'fp16', 'bf16' or 'fp8'
93
+
94
+ # note: gradients are effectively computed over batch_size * gradient_accumulation_steps samples; in plain words, accumulating over n steps acts like an n-times larger batch size
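+ # e.g. with batch_size_per_gpu = 16 and gradient_accumulation_steps = 8, the effective batch size is 16 * 8 = 128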
95
+ gradient_accumulation_steps: int = 8 # number of gradient accumulation steps
96
+
97
+ warmup_steps: int = 1024 # warmup steps; warmup samples = warmup_steps * batch_size * gradient_accumulation_steps
98
+
99
+ tokenizer_dir: str = PROJECT_ROOT + '/model_save/' # the tokenizer usually sits in the same folder as the model weights
100
+ model_file: str = PROJECT_ROOT + '/model_save/chat_small_t5.{}.bin'
101
+ model_config_file: str = PROJECT_ROOT + '/model_save/model_config.json'
102
+ train_file: str = PROJECT_ROOT + '/data/my_train_dataset.parquet'
103
+ validation_file: str = PROJECT_ROOT + '/data/my_valid_dataset.parquet'
104
+ test_file: str = PROJECT_ROOT + '/data/my_test_dataset.parquet'
105
+
106
+ # which checkpoint to fine-tune from; only takes effect when the train function is called with is_finetune = True
+ # remember to freeze some layers or lower the learning rate when fine-tuning
108
+ finetune_from_ckp_file = PROJECT_ROOT + '/model_save/chat_small_t5.best.bin'
109
+
110
+ # training-state save directory; training can resume from here after an interruption
111
+ train_state_dir: str = PROJECT_ROOT + '/model_save/train_latest_state'
112
+ output_dir: str = PROJECT_ROOT + '/model_save/pretrain'
113
+
114
+ logging_steps: int = 50
115
+ save_steps: int = 10000
116
+
117
+ # dataset_cache_dir: str = PROJECT_ROOT + '/data/.cache'
118
+ # trainer_log_file: str = PROJECT_ROOT + '/logs/trainer.log'
119
+
120
+ keep_latest_n_ckp: int = 8 # maximum number of best-scoring model files to keep during training
121
+
122
+ seed: int = 23333
123
+ dataloader_buffer_size: int = 50000
124
+ max_seq_len: int = 256 # maximum sentence length, default: 256
125
+
126
+
127
+ #======================================================================================
128
+ # model configuration
129
+ @dataclass
130
+ class T5ModelConfig:
131
+
132
+ d_ff: int = 3072 # feed-forward layer dimension
133
+
134
+ d_model: int = 768 # embedding dimension
+ num_heads: int = 12 # number of attention heads, d_model // num_heads == d_kv
+ d_kv: int = 64 # d_model // num_heads
137
+
138
+ num_decoder_layers: int = 10 # number of Transformer decoder layers
+ num_layers: int = 10 # number of Transformer encoder layers
data/my_test_dataset_2k.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8a99f671c9bf8dfbddf8a1aaf13decbf7eea440c07a2631e2c634ee6cd5dded
3
+ size 575315
data/my_train_dataset_3k.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cbe91a996f659e77d1047453686a6872ff5a5ce5a9f5026028d3edb6def6f4f9
3
+ size 855994
data/my_valid_dataset_1k.parquet ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfdd45edb8aeaf49089795cf208f04d9baea0922883e87c4fdd33af350029092
3
+ size 286692
dpo_train.py ADDED
@@ -0,0 +1,203 @@
1
+ # coding=utf-8
2
+ from typing import Dict, Optional
3
+ import time
4
+ import os
5
+
6
+ import pandas as pd
7
+ import torch
8
+ from datasets import Dataset, load_dataset
9
+ from transformers import PreTrainedTokenizerFast, TrainingArguments
10
+ from trl import DPOTrainer
11
+ from tokenizers import Tokenizer
12
+ from peft import LoraConfig, TaskType, PeftModel
13
+
14
+ from config import DpoConfig, T5ModelConfig
15
+ from model.chat_model import TextToTextModel
16
+ from utils.functions import get_T5_config
17
+
18
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
19
+
20
+ def get_dataset(split: str, file: str, cache_dir: str = '.cache') -> Dataset:
21
+ """Load the DPO dataset from a local json file and convert it to the necessary format.
22
+
23
+ The dataset is converted to a dictionary with the following structure:
24
+ {
25
+ 'prompt': List[str],
26
+ 'chosen': List[str],
27
+ 'rejected': List[str],
28
+ }
29
+ """
30
+ dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)
31
+
32
+ def split_prompt_and_responses(sample: dict) -> Dict[str, str]:
33
+ return {
34
+ # add an eos token to signal the end of sentence; used during generation.
35
+ "prompt": f"{sample['prompt']}[EOS]",
36
+ "chosen": f"{sample['chosen']}[EOS]",
37
+ "rejected": f"{sample['rejected']}[EOS]",
38
+ }
39
+
40
+ return dataset.map(split_prompt_and_responses).shuffle(2333)
41
+
42
+
43
+ def train_dpo(config: DpoConfig, peft_config: LoraConfig=None) -> None:
44
+
45
+ # step 1. load the tokenizer
46
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)
47
+
48
+ # step 2. load the pre-trained model
49
+ model_train, model_ref = None, None
50
+ if os.path.isdir(config.sft_model_file):
51
+ # if a directory is given, use from_pretrained
52
+ model_train = TextToTextModel.from_pretrained(config.sft_model_file)
53
+ model_ref = TextToTextModel.from_pretrained(config.sft_model_file)
54
+ else:
55
+ # load_state_dict
56
+ t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
57
+
58
+ model_train = TextToTextModel(t5_config)
59
+ model_train.load_state_dict(torch.load(config.sft_model_file, map_location='cpu')) # set cpu for no exception
60
+
61
+ model_ref = TextToTextModel(t5_config)
62
+ model_ref.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))
63
+
64
+ # 4. load the training dataset
65
+ train_dataset = get_dataset("train", file=config.dpo_train_file)
66
+
67
+ # 5. load the evaluation dataset
68
+ # eval_dataset = get_dataset("train", file=config.dpo_eval_file)
69
+ eval_dataset = None
70
+
71
+ # 6. initialize the training arguments
72
+ training_args = TrainingArguments(
73
+ per_device_train_batch_size=config.per_device_train_batch_size,
74
+ num_train_epochs=config.num_train_epochs,
75
+ auto_find_batch_size=True,
76
+ remove_unused_columns=False,
77
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
78
+ learning_rate=config.learning_rate,
79
+ logging_first_step=True,
80
+ logging_steps=config.logging_steps,
81
+ save_steps=config.save_steps,
82
+ output_dir=config.output_dir,
83
+ optim="adafactor",
84
+ report_to="tensorboard",
85
+ log_level='info',
86
+ warmup_steps=config.warmup_steps,
87
+ bf16=False,
88
+ fp16=config.fp16,
89
+ seed=config.seed,
90
+ logging_dir=config.log_dir,
91
+ )
92
+
93
+ # 7. initialize the DPO trainer
94
+ dpo_trainer = DPOTrainer(
95
+ model_train,
96
+ model_ref,
97
+ peft_config=peft_config,
98
+ args=training_args,
99
+ beta=config.beta,
100
+ train_dataset=train_dataset,
101
+ eval_dataset=eval_dataset,
102
+ tokenizer=tokenizer,
103
+ max_length=config.max_seq_len,
104
+ max_target_length=config.max_seq_len,
105
+ max_prompt_length=config.max_seq_len,
106
+ generate_during_eval=True,
107
+ is_encoder_decoder=True,
108
+ )
109
+
110
+ # 8. train
111
+ dpo_trainer.train(
112
+ # resume_from_checkpoint=True
113
+ )
114
+
115
+ # 9. save log
116
+ loss_log = pd.DataFrame(dpo_trainer.state.log_history)
117
+ log_dir = './logs'
118
+ if not os.path.exists(log_dir):
119
+ os.mkdir(log_dir)
120
+ loss_log.to_csv(f"{log_dir}/dpo_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")
121
+
122
+ # 10. save the model / lora adapter
+ suffix = '/lora/' if peft_config is not None else '/dpo'
+ model_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + suffix
125
+
126
+ dpo_trainer.save_model(model_save_dir)
127
+ print('save model or lora adapter to: {}'.format(model_save_dir))
128
+
129
+ def merge_lora_weight_into_model(config: DpoConfig, peft_config: LoraConfig) -> None:
130
+
131
+ # step 1. load the tokenizer
132
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)
133
+
134
+ # step 2. load the pre-trained model
135
+ sft_model = None
136
+ if os.path.isdir(config.sft_model_file):
137
+ # if a directory is given, use from_pretrained
138
+ sft_model = TextToTextModel.from_pretrained(config.sft_model_file)
139
+ else:
140
+ # load_state_dict
141
+ t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
142
+ sft_model = TextToTextModel(t5_config)
143
+ sft_model.load_state_dict(torch.load(config.sft_model_file, map_location='cpu')) # set cpu for no exception
144
+
145
+ # note: this path must match the model_save_dir used in train_dpo above:
+ # 10. save the model / lora adapter
+ # suffix = '/lora/' if peft_config is not None else '/dpo'
+ # model_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + suffix
150
+
151
+ adapter_save_dir = '/'.join(config.sft_model_file.split('/')[0: -1]) + '/lora'
152
+
153
+ peft_model = PeftModel.from_pretrained(
154
+ model=sft_model,
155
+ model_id=adapter_save_dir,
156
+ config=peft_config,
157
+ adapter_name='adapter',
158
+ )
159
+
160
+ # peft_model = PeftModel(
161
+ # model=sft_model,
162
+ # peft_config=peft_config,
163
+ # adapter_name='adapter',
164
+ # )
165
+
166
+ # 3. load adapter
167
+
168
+ print('load adapter from dir: {}'.format(adapter_save_dir))
169
+
170
+ peft_model.load_adapter(model_id=adapter_save_dir, adapter_name='adapter',)
171
+
172
+ # 4. merge
173
+ peft_model = peft_model.merge_and_unload()
174
+
175
+ # 5. save
176
+ save_merge_file = config.sft_model_file + '.dpo_lora_merged'
177
+ sft_model.save_pretrained(save_merge_file)
178
+ print('save merge model file to: {}'.format(save_merge_file))
179
+
180
+
181
+ if __name__ == "__main__":
182
+
183
+ peft_config = LoraConfig(
184
+ task_type=TaskType.SEQ_2_SEQ_LM, # text 2 text lora model
185
+ inference_mode=False,
186
+ r=16,
187
+ lora_alpha=16,
188
+ lora_dropout=0.1,
189
+ bias="all",
190
+ )
191
+
192
+ dpo_config = DpoConfig()
193
+
194
+ # 1. train
195
+ train_dpo(dpo_config, peft_config=None)
196
+
197
+ # 2. merge lora adapter into model
198
+ # merge_lora_weight_into_model(dpo_config, peft_config)
199
+
200
+
201
+
202
+
203
+
eval/.gitignore ADDED
@@ -0,0 +1,5 @@
1
+ ceval-exam
2
+ data
3
+ result
4
+ CMMLU
5
+ result_0_shot
eval/c_eavl.ipynb ADDED
@@ -0,0 +1,657 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "## Download the C-Eval dataset\n",
8
+ "\n",
9
+ "```bash\n",
10
+ "mkdir ceval-data\n",
11
+ "cd ceval-data\n",
12
+ "wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip \n",
13
+ "unzip ceval-exam.zip -d ceval-exam\n",
14
+ "wget https://raw.githubusercontent.com/hkust-nlp/ceval/main/subject_mapping.json\n",
15
+ "```"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": 1,
21
+ "metadata": {},
22
+ "outputs": [
23
+ {
24
+ "name": "stdout",
25
+ "output_type": "stream",
26
+ "text": [
27
+ "dev\n",
28
+ "subject_mapping.json\n",
29
+ "test\n",
30
+ "val\n"
31
+ ]
32
+ }
33
+ ],
34
+ "source": [
35
+ "! ls ceval-exam"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 2,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "import os, re\n",
45
+ "import ujson\n",
46
+ "import torch\n",
47
+ "import pandas as pd\n",
48
+ "from tqdm import tqdm\n",
49
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
50
+ "from transformers.generation.configuration_utils import GenerationConfig\n",
51
+ "from transformers.generation.utils import LogitsProcessorList, InfNanRemoveLogitsProcessor"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 3,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "ceval_dir = './ceval-exam'\n",
61
+ "result_save_dir = './result'\n",
62
+ "model_dir = '../model_save/dpo' # the model files are one directory up; use the model after DPO\n",
63
+ "\n",
64
+ "if not os.path.exists(result_save_dir):\n",
65
+ " os.mkdir(result_save_dir)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 4,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "subject_files = os.listdir(f\"{ceval_dir}/val\")\n",
75
+ "subjects = [subject.replace('_val.csv', '') for subject in subject_files]\n",
76
+ "\n",
77
+ "subject_mapping = {}\n",
78
+ "with open('./ceval-exam/subject_mapping.json', 'r', encoding='utf-8') as f:\n",
79
+ " subject_mapping = ujson.load(f)"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "metadata": {},
85
+ "source": [
86
+ "Because this project's model dropped many input-bearing samples at the SFT stage and was never fine-tuned for answering questions, feeding a question in directly just makes it explain the keywords in the question. The C-Eval test therefore predicts the 'A', 'B', 'C', 'D' tokens instead.\n",
+ "> Sometimes, however, especially in zero-shot tests and with models that have not been instruction-tuned, the model may not understand the instruction well and may not even answer the question. In that case we recommend directly computing the probability that the next predicted token equals \"A\", \"B\", \"C\" or \"D\", and taking the option with the highest probability as the answer \n",
+ "> -- this is a constrained-decoding approach; the official MMLU test code uses this method. Note that this probability method is not applicable to chain-of-thought tests.\n",
+ "\n",
+ "See: [如何在C-Eval上测试](https://github.com/hkust-nlp/ceval/blob/main/README_zh.md#如何在C-Eval上测试)\n",
+ "\n",
+ "Evaluation mode: zero-shot (chatbot / dialogue mode) \n",
+ "The dev split is intended for few-shot and is not used for now"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 5,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "def format_prompt(df: pd.Series) -> str:\n",
103
+ " '''\n",
104
+ "    format the 'question', 'A', 'B', 'C', 'D' columns of df into a question prompt\n",
105
+ " '''\n",
106
+ " prompt = f\"请回答单选题,回答字母A、B、C、D即可。问题:\\n{df['question']}\\n答案选项:\\n\"\n",
107
+ " for col in ['A', 'B', 'C', 'D']:\n",
108
+ " prompt += f\"{col}:{df[col]}\\n\"\n",
109
+ " \n",
110
+ " return prompt"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 6,
116
+ "metadata": {},
117
+ "outputs": [
118
+ {
119
+ "data": {
120
+ "text/plain": [
121
+ "['Accountant', '注册会计师', 'Other']"
122
+ ]
123
+ },
124
+ "execution_count": 6,
125
+ "metadata": {},
126
+ "output_type": "execute_result"
127
+ }
128
+ ],
129
+ "source": [
130
+ "subject_mapping['accountant']"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 7,
136
+ "metadata": {},
137
+ "outputs": [
138
+ {
139
+ "name": "stderr",
140
+ "output_type": "stream",
141
+ "text": [
142
+ "100%|██████████| 52/52 [00:00<00:00, 617.74it/s]\n"
143
+ ]
144
+ }
145
+ ],
146
+ "source": [
147
+ "do_test = False\n",
148
+ "all_eval_items = []\n",
149
+ "for i, subject_name in tqdm(enumerate(subjects), total=len(subjects)):\n",
150
+ " val_file = f\"{ceval_dir}/val/{subject_name}_val.csv\"\n",
151
+ " test_file = f\"{ceval_dir}/test/{subject_name}_test.csv\"\n",
152
+ "\n",
153
+ " val_df = pd.read_csv(test_file) if do_test else pd.read_csv(val_file)\n",
154
+ " \n",
155
+ " for idx, row in val_df.iterrows():\n",
156
+ "        question = format_prompt(row)\n",
157
+ " answer = row['answer'] if 'answer' in val_df.columns else '' \n",
158
+ "\n",
159
+ " item = {\n",
160
+ " 'subject_en': subject_mapping[subject_name][0],\n",
161
+ " 'subject_zh': subject_mapping[subject_name][1],\n",
162
+ "            'category': subject_mapping[subject_name][2], # category (one of STEM, Social Science, Humanities, Other)\n",
163
+ "            'question': question,\n",
164
+ " 'answer':answer,\n",
165
+ " }\n",
166
+ " \n",
167
+ " all_eval_items.append(item)"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 8,
173
+ "metadata": {},
174
+ "outputs": [
175
+ {
176
+ "data": {
177
+ "text/html": [
178
+ "<div>\n",
179
+ "<style scoped>\n",
180
+ " .dataframe tbody tr th:only-of-type {\n",
181
+ " vertical-align: middle;\n",
182
+ " }\n",
183
+ "\n",
184
+ " .dataframe tbody tr th {\n",
185
+ " vertical-align: top;\n",
186
+ " }\n",
187
+ "\n",
188
+ " .dataframe thead th {\n",
189
+ " text-align: right;\n",
190
+ " }\n",
191
+ "</style>\n",
192
+ "<table border=\"1\" class=\"dataframe\">\n",
193
+ " <thead>\n",
194
+ " <tr style=\"text-align: right;\">\n",
195
+ " <th></th>\n",
196
+ " <th>subject_en</th>\n",
197
+ " <th>subject_zh</th>\n",
198
+ " <th>category</th>\n",
199
+ " <th>question</th>\n",
200
+ " <th>answer</th>\n",
201
+ " </tr>\n",
202
+ " </thead>\n",
203
+ " <tbody>\n",
204
+ " <tr>\n",
205
+ " <th>0</th>\n",
206
+ " <td>Accountant</td>\n",
207
+ " <td>注册会计师</td>\n",
208
+ " <td>Other</td>\n",
209
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是...</td>\n",
210
+ " <td>D</td>\n",
211
+ " </tr>\n",
212
+ " <tr>\n",
213
+ " <th>1</th>\n",
214
+ " <td>Accountant</td>\n",
215
+ " <td>注册会计师</td>\n",
216
+ " <td>Other</td>\n",
217
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动...</td>\n",
218
+ " <td>C</td>\n",
219
+ " </tr>\n",
220
+ " <tr>\n",
221
+ " <th>2</th>\n",
222
+ " <td>Accountant</td>\n",
223
+ " <td>注册会计师</td>\n",
224
+ " <td>Other</td>\n",
225
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各...</td>\n",
226
+ " <td>D</td>\n",
227
+ " </tr>\n",
228
+ " <tr>\n",
229
+ " <th>3</th>\n",
230
+ " <td>Accountant</td>\n",
231
+ " <td>注册会计师</td>\n",
232
+ " <td>Other</td>\n",
233
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n税务行政复议的申请人可以在得知税务机关作...</td>\n",
234
+ " <td>A</td>\n",
235
+ " </tr>\n",
236
+ " <tr>\n",
237
+ " <th>4</th>\n",
238
+ " <td>Accountant</td>\n",
239
+ " <td>注册会计师</td>\n",
240
+ " <td>Other</td>\n",
241
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n关于战略管理表述错误的是____。\\n答...</td>\n",
242
+ " <td>C</td>\n",
243
+ " </tr>\n",
244
+ " </tbody>\n",
245
+ "</table>\n",
246
+ "</div>"
247
+ ],
248
+ "text/plain": [
249
+ " subject_en subject_zh category \\\n",
250
+ "0 Accountant 注册会计师 Other \n",
251
+ "1 Accountant 注册会计师 Other \n",
252
+ "2 Accountant 注册会计师 Other \n",
253
+ "3 Accountant 注册会计师 Other \n",
254
+ "4 Accountant 注册会计师 Other \n",
255
+ "\n",
256
+ " question answer \n",
257
+ "0 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... D \n",
258
+ "1 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... C \n",
259
+ "2 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... D \n",
260
+ "3 请回答单选题,回答字母A、B、C、D即可。问题:\\n税务行政复议的申请人可以在得知税务机关作... A \n",
261
+ "4 请回答单选题,回答字母A、B、C、D即可。问题:\\n关于战略管理表述错误的是____。\\n答... C "
262
+ ]
263
+ },
264
+ "execution_count": 8,
265
+ "metadata": {},
266
+ "output_type": "execute_result"
267
+ }
268
+ ],
269
+ "source": [
270
+ "eval_df = pd.DataFrame(all_eval_items)\n",
271
+ "eval_df.head(5)"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": 9,
277
+ "metadata": {},
278
+ "outputs": [
279
+ {
280
+ "data": {
281
+ "text/plain": [
282
+ "[872, 873, 884, 886]"
283
+ ]
284
+ },
285
+ "execution_count": 9,
286
+ "metadata": {},
287
+ "output_type": "execute_result"
288
+ }
289
+ ],
290
+ "source": [
291
+ "# load the model\n",
292
+ "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
293
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)\n",
294
+ "\n",
295
+ "generation_config = GenerationConfig()\n",
296
+ "generation_config.remove_invalid_values = True # automatically adds InfNanRemoveLogitsProcessor\n",
297
+ "generation_config.eos_token_id = tokenizer.eos_token_id\n",
298
+ "generation_config.pad_token_id = tokenizer.pad_token_id\n",
299
+ "# for t5, set decoder_start_token_id = pad_token_id\n",
300
+ "generation_config.decoder_start_token_id = tokenizer.pad_token_id \n",
301
+ "generation_config.max_new_tokens = 16\n",
302
+ "generation_config.num_beams = 1\n",
303
+ "generation_config.do_sample = False # greedy search\n",
304
+ "\n",
305
+ "choices = ['A', 'B', 'C', 'D']\n",
306
+ "choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n",
307
+ "choices_ids"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 10,
313
+ "metadata": {},
314
+ "outputs": [
315
+ {
316
+ "name": "stderr",
317
+ "output_type": "stream",
318
+ "text": [
319
+ "100%|██████████| 1346/1346 [00:20<00:00, 64.11it/s]\n"
320
+ ]
321
+ }
322
+ ],
323
+ "source": [
324
+ "batch_size = 32\n",
325
+ "batch_data, batch_answers = [], []\n",
326
+ "n = len(eval_df)\n",
327
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
328
+ "model.to(device)\n",
329
+ "model.eval()\n",
330
+ "\n",
331
+ "for idx, row in tqdm(eval_df.iterrows(), total=n):\n",
332
+ " batch_data.append(row['question'])\n",
333
+ " \n",
334
+ " if len(batch_data) == batch_size or idx == n - 1:\n",
335
+ " torch.cuda.empty_cache()\n",
336
+ " \n",
337
+ " encode_ids = tokenizer(batch_data, padding=True)\n",
338
+ " input_ids, attention_mask = torch.LongTensor(encode_ids['input_ids']), torch.LongTensor(encode_ids['attention_mask'])\n",
339
+ " \n",
340
+ " outputs = model.generate(\n",
341
+ " input_ids=input_ids.to(device),\n",
342
+ " attention_mask=attention_mask.to(device),\n",
343
+ " generation_config=generation_config,\n",
344
+ " return_dict_in_generate=True,\n",
345
+ " output_scores=True,\n",
346
+ " )\n",
347
+ "\n",
348
+ " scores = torch.stack(outputs['scores'], dim=1)\n",
349
+ " scores = torch.softmax(scores, dim=2)\n",
350
+ "        scores = scores[..., 0, choices_ids] # take the A/B/C/D probabilities at the first generated token\n",
351
+ " choices_index = torch.argmax(scores, dim=1)\n",
352
+ " \n",
353
+ " for i in choices_index:\n",
354
+ " batch_answers.append(choices[i])\n",
355
+ "\n",
356
+ " batch_data = []"
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": 11,
362
+ "metadata": {},
363
+ "outputs": [],
364
+ "source": [
365
+ "eval_df.insert(loc=5, column='model_predict', value=batch_answers)\n",
366
+ "val_df = eval_df.copy(deep=True)"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": 12,
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": [
375
+ "val_df['is_correct'] = val_df['model_predict'] == val_df['answer']\n",
376
+ "val_df['is_correct'] = val_df['is_correct'].astype(pd.Int16Dtype())"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": 13,
382
+ "metadata": {},
383
+ "outputs": [
384
+ {
385
+ "data": {
386
+ "text/html": [
387
+ "<div>\n",
388
+ "<style scoped>\n",
389
+ " .dataframe tbody tr th:only-of-type {\n",
390
+ " vertical-align: middle;\n",
391
+ " }\n",
392
+ "\n",
393
+ " .dataframe tbody tr th {\n",
394
+ " vertical-align: top;\n",
395
+ " }\n",
396
+ "\n",
397
+ " .dataframe thead th {\n",
398
+ " text-align: right;\n",
399
+ " }\n",
400
+ "</style>\n",
401
+ "<table border=\"1\" class=\"dataframe\">\n",
402
+ " <thead>\n",
403
+ " <tr style=\"text-align: right;\">\n",
404
+ " <th></th>\n",
405
+ " <th>subject_en</th>\n",
406
+ " <th>subject_zh</th>\n",
407
+ " <th>category</th>\n",
408
+ " <th>question</th>\n",
409
+ " <th>answer</th>\n",
410
+ " <th>model_predict</th>\n",
411
+ " <th>is_correct</th>\n",
412
+ " </tr>\n",
413
+ " </thead>\n",
414
+ " <tbody>\n",
415
+ " <tr>\n",
416
+ " <th>0</th>\n",
417
+ " <td>Accountant</td>\n",
418
+ " <td>注册会计师</td>\n",
419
+ " <td>Other</td>\n",
420
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是...</td>\n",
421
+ " <td>D</td>\n",
422
+ " <td>A</td>\n",
423
+ " <td>0</td>\n",
424
+ " </tr>\n",
425
+ " <tr>\n",
426
+ " <th>1</th>\n",
427
+ " <td>Accountant</td>\n",
428
+ " <td>注册会计师</td>\n",
429
+ " <td>Other</td>\n",
430
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动...</td>\n",
431
+ " <td>C</td>\n",
432
+ " <td>A</td>\n",
433
+ " <td>0</td>\n",
434
+ " </tr>\n",
435
+ " <tr>\n",
436
+ " <th>2</th>\n",
437
+ " <td>Accountant</td>\n",
438
+ " <td>注册会计师</td>\n",
439
+ " <td>Other</td>\n",
440
+ " <td>请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各...</td>\n",
441
+ " <td>D</td>\n",
442
+ " <td>A</td>\n",
443
+ " <td>0</td>\n",
444
+ " </tr>\n",
445
+ " </tbody>\n",
446
+ "</table>\n",
447
+ "</div>"
448
+ ],
449
+ "text/plain": [
450
+ " subject_en subject_zh category \\\n",
451
+ "0 Accountant 注册会计师 Other \n",
452
+ "1 Accountant 注册会计师 Other \n",
453
+ "2 Accountant 注册会计师 Other \n",
454
+ "\n",
455
+ " question answer model_predict \\\n",
456
+ "0 请回答单选题,回答字母A、B、C、D即可。问题:\\n下列关于税法基本原则的表述中,不正确的是... D A \n",
457
+ "1 请回答单选题,回答字母A、B、C、D即可。问题:\\n甲公司是国内一家领先的新媒体、通信及移动... C A \n",
458
+ "2 请回答单选题,回答字母A、B、C、D即可。问题:\\n根据我国《印花税暂行条例》的规定,下列各... D A \n",
459
+ "\n",
460
+ " is_correct \n",
461
+ "0 0 \n",
462
+ "1 0 \n",
463
+ "2 0 "
464
+ ]
465
+ },
466
+ "execution_count": 13,
467
+ "metadata": {},
468
+ "output_type": "execute_result"
469
+ }
470
+ ],
471
+ "source": [
472
+ "val_df.head(3)"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "code",
477
+ "execution_count": 14,
478
+ "metadata": {},
479
+ "outputs": [
480
+ {
481
+ "data": {
482
+ "text/html": [
483
+ "<div>\n",
484
+ "<style scoped>\n",
485
+ " .dataframe tbody tr th:only-of-type {\n",
486
+ " vertical-align: middle;\n",
487
+ " }\n",
488
+ "\n",
489
+ " .dataframe tbody tr th {\n",
490
+ " vertical-align: top;\n",
491
+ " }\n",
492
+ "\n",
493
+ " .dataframe thead th {\n",
494
+ " text-align: right;\n",
495
+ " }\n",
496
+ "</style>\n",
497
+ "<table border=\"1\" class=\"dataframe\">\n",
498
+ " <thead>\n",
499
+ " <tr style=\"text-align: right;\">\n",
500
+ " <th></th>\n",
501
+ " <th>is_correct</th>\n",
502
+ " </tr>\n",
503
+ " <tr>\n",
504
+ " <th>category</th>\n",
505
+ " <th></th>\n",
506
+ " </tr>\n",
507
+ " </thead>\n",
508
+ " <tbody>\n",
509
+ " <tr>\n",
510
+ " <th>Humanities</th>\n",
511
+ " <td>63</td>\n",
512
+ " </tr>\n",
513
+ " <tr>\n",
514
+ " <th>Other</th>\n",
515
+ " <td>89</td>\n",
516
+ " </tr>\n",
517
+ " <tr>\n",
518
+ " <th>STEM</th>\n",
519
+ " <td>89</td>\n",
520
+ " </tr>\n",
521
+ " <tr>\n",
522
+ " <th>Social Science</th>\n",
523
+ " <td>72</td>\n",
524
+ " </tr>\n",
525
+ " </tbody>\n",
526
+ "</table>\n",
527
+ "</div>"
528
+ ],
529
+ "text/plain": [
530
+ " is_correct\n",
531
+ "category \n",
532
+ "Humanities 63\n",
533
+ "Other 89\n",
534
+ "STEM 89\n",
535
+ "Social Science 72"
536
+ ]
537
+ },
538
+ "execution_count": 14,
539
+ "metadata": {},
540
+ "output_type": "execute_result"
541
+ }
542
+ ],
543
+ "source": [
544
+ "final_df = val_df.groupby('category').sum('is_correct')\n",
545
+ "final_df"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": 15,
551
+ "metadata": {},
552
+ "outputs": [
553
+ {
554
+ "data": {
555
+ "text/html": [
556
+ "<div>\n",
557
+ "<style scoped>\n",
558
+ " .dataframe tbody tr th:only-of-type {\n",
559
+ " vertical-align: middle;\n",
560
+ " }\n",
561
+ "\n",
562
+ " .dataframe tbody tr th {\n",
563
+ " vertical-align: top;\n",
564
+ " }\n",
565
+ "\n",
566
+ " .dataframe thead th {\n",
567
+ " text-align: right;\n",
568
+ " }\n",
569
+ "</style>\n",
570
+ "<table border=\"1\" class=\"dataframe\">\n",
571
+ " <thead>\n",
572
+ " <tr style=\"text-align: right;\">\n",
573
+ " <th></th>\n",
574
+ " <th>is_correct</th>\n",
575
+ " <th>question_count</th>\n",
576
+ " <th>accuracy</th>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <th>category</th>\n",
580
+ " <th></th>\n",
581
+ " <th></th>\n",
582
+ " <th></th>\n",
583
+ " </tr>\n",
584
+ " </thead>\n",
585
+ " <tbody>\n",
586
+ " <tr>\n",
587
+ " <th>Humanities</th>\n",
588
+ " <td>63</td>\n",
589
+ " <td>257</td>\n",
590
+ " <td>24.51%</td>\n",
591
+ " </tr>\n",
592
+ " <tr>\n",
593
+ " <th>Other</th>\n",
594
+ " <td>89</td>\n",
595
+ " <td>384</td>\n",
596
+ " <td>23.18%</td>\n",
597
+ " </tr>\n",
598
+ " <tr>\n",
599
+ " <th>STEM</th>\n",
600
+ " <td>89</td>\n",
601
+ " <td>430</td>\n",
602
+ " <td>20.70%</td>\n",
603
+ " </tr>\n",
604
+ " <tr>\n",
605
+ " <th>Social Science</th>\n",
606
+ " <td>72</td>\n",
607
+ " <td>275</td>\n",
608
+ " <td>26.18%</td>\n",
609
+ " </tr>\n",
610
+ " </tbody>\n",
611
+ "</table>\n",
612
+ "</div>"
613
+ ],
614
+ "text/plain": [
615
+ " is_correct question_count accuracy\n",
616
+ "category \n",
617
+ "Humanities 63 257 24.51%\n",
618
+ "Other 89 384 23.18%\n",
619
+ "STEM 89 430 20.70%\n",
620
+ "Social Science 72 275 26.18%"
621
+ ]
622
+ },
623
+ "execution_count": 15,
624
+ "metadata": {},
625
+ "output_type": "execute_result"
626
+ }
627
+ ],
628
+ "source": [
629
+ "final_df['question_count'] = val_df.groupby('category').count()['question']\n",
630
+ "final_df['accuracy'] = final_df['is_correct'] / final_df['question_count']\n",
631
+ "final_df['accuracy'] = final_df['accuracy'] .apply(lambda x: format(x, '.2%'))\n",
632
+ "final_df"
633
+ ]
634
+ }
635
+ ],
636
+ "metadata": {
637
+ "kernelspec": {
638
+ "display_name": "py310",
639
+ "language": "python",
640
+ "name": "python3"
641
+ },
642
+ "language_info": {
643
+ "codemirror_mode": {
644
+ "name": "ipython",
645
+ "version": 3
646
+ },
647
+ "file_extension": ".py",
648
+ "mimetype": "text/x-python",
649
+ "name": "python",
650
+ "nbconvert_exporter": "python",
651
+ "pygments_lexer": "ipython3",
652
+ "version": "3.10.12"
653
+ }
654
+ },
655
+ "nbformat": 4,
656
+ "nbformat_minor": 2
657
+ }
eval/cmmlu.ipynb ADDED
@@ -0,0 +1,241 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import torch\n",
11
+ "import numpy as np\n",
12
+ "import sys\n",
13
+ "root = '/'.join(os.path.realpath('.').replace('\\\\','/').split('/'))\n",
14
+ "p = root + '/CMMLU/src'\n",
15
+ "if p not in sys.path:\n",
16
+ " sys.path.append(p)\n",
17
+ "import argparse\n",
18
+ "from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval\n",
19
+ "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
20
+ "from transformers.generation.configuration_utils import GenerationConfig"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "metadata": {},
26
+ "source": [
27
+ "```bash\n",
28
+ "git clone --depth 1 https://github.com/haonan-li/CMMLU.git\n",
29
+ "```\n",
30
+ "\n",
31
+ "copied from https://github.com/haonan-li/CMMLU/blob/master/src/hf_causal_model.py"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "model_dir = '../model_save/dpo' # the model files are one directory up; use the model after DPO\n",
41
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
42
+ "# load the model\n",
43
+ "tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
44
+ "model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)\n",
45
+ "generation_config = GenerationConfig()\n",
46
+ "generation_config.remove_invalid_values = True # automatically adds InfNanRemoveLogitsProcessor\n",
47
+ "generation_config.eos_token_id = tokenizer.eos_token_id\n",
48
+ "generation_config.pad_token_id = tokenizer.pad_token_id\n",
49
+ "# for t5, set decoder_start_token_id = pad_token_id\n",
50
+ "generation_config.decoder_start_token_id = tokenizer.pad_token_id \n",
51
+ "generation_config.max_new_tokens = 1\n",
52
+ "generation_config.num_beams = 1\n",
53
+ "generation_config.do_sample = False # greedy search\n",
54
+ "\n",
55
+ "choices = ['A', 'B', 'C', 'D']\n",
56
+ "choices_ids = [tokenizer.convert_tokens_to_ids(c) for c in choices]\n",
57
+ "choices_ids"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 3,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "def eval(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):\n",
67
+ " choice_ids = [tokenizer.convert_tokens_to_ids(choice) for choice in choices]\n",
68
+ " cors = []\n",
69
+ " all_conf = []\n",
70
+ " all_preds = []\n",
71
+ " answers = choices[: test_df.shape[1] - 2]\n",
72
+ "\n",
73
+ " for i in range(test_df.shape[0]):\n",
74
+ " prompt_end = format_example(test_df, i, subject, include_answer=False)\n",
75
+ " prompt = gen_prompt(dev_df=dev_df,\n",
76
+ " subject=subject,\n",
77
+ " prompt_end=prompt_end,\n",
78
+ " num_few_shot=num_few_shot,\n",
79
+ " tokenizer=tokenizer,\n",
80
+ " max_length=max_length)\n",
81
+ " inputs = tokenizer([prompt])\n",
82
+ " if \"token_type_ids\" in inputs: # For Falcon\n",
83
+ " inputs.pop(\"token_type_ids\")\n",
84
+ " label = test_df.iloc[i, test_df.shape[1] - 1]\n",
85
+ " torch.cuda.empty_cache()\n",
86
+ " \n",
87
+ " input_ids, attention_mask = torch.LongTensor(inputs['input_ids']), torch.LongTensor(inputs['attention_mask'])\n",
88
+ " \n",
89
+ " with torch.no_grad():\n",
90
+ " outputs = model.generate(\n",
91
+ " input_ids=input_ids.to(device),\n",
92
+ " attention_mask=attention_mask.to(device),\n",
93
+ " generation_config=generation_config,\n",
94
+ " return_dict_in_generate=True,\n",
95
+ " output_scores=True,\n",
96
+ " )\n",
97
+ " \n",
98
+ " scores = torch.stack(outputs['scores'], dim=1).to('cpu')\n",
99
+ " scores = torch.softmax(scores, dim=2)\n",
100
+ "        scores = scores[..., 0, choices_ids] # take the A/B/C/D probabilities at the first generated token\n",
101
+ " conf = scores[0][choices.index(label)]\n",
102
+ " choices_index = torch.argmax(scores)\n",
103
+ " \n",
104
+ " pred = choices[choices_index]\n",
105
+ "\n",
106
+ " all_preds += pred\n",
107
+ " all_conf.append(conf)\n",
108
+ " cors.append(pred == label)\n",
109
+ "\n",
110
+ " acc = np.mean(cors)\n",
111
+ " print(\"Average accuracy {:.3f} - {}\".format(acc, subject))\n",
112
+ " return acc, all_preds, conf"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 4,
118
+ "metadata": {},
119
+ "outputs": [
120
+ {
121
+ "name": "stdout",
122
+ "output_type": "stream",
123
+ "text": [
124
+ "Average accuracy 0.243 - agronomy\n",
125
+ "Average accuracy 0.243 - anatomy\n",
126
+ "Average accuracy 0.256 - ancient_chinese\n",
127
+ "Average accuracy 0.256 - arts\n",
128
+ "Average accuracy 0.248 - astronomy\n",
129
+ "Average accuracy 0.234 - business_ethics\n",
130
+ "Average accuracy 0.256 - chinese_civil_service_exam\n",
131
+ "Average accuracy 0.260 - chinese_driving_rule\n",
132
+ "Average accuracy 0.235 - chinese_food_culture\n",
133
+ "Average accuracy 0.252 - chinese_foreign_policy\n",
134
+ "Average accuracy 0.251 - chinese_history\n",
135
+ "Average accuracy 0.250 - chinese_literature\n",
136
+ "Average accuracy 0.246 - chinese_teacher_qualification\n",
137
+ "Average accuracy 0.253 - clinical_knowledge\n",
138
+ "Average accuracy 0.245 - college_actuarial_science\n",
139
+ "Average accuracy 0.318 - college_education\n",
140
+ "Average accuracy 0.302 - college_engineering_hydrology\n",
141
+ "Average accuracy 0.213 - college_law\n",
142
+ "Average accuracy 0.219 - college_mathematics\n",
143
+ "Average accuracy 0.264 - college_medical_statistics\n",
144
+ "Average accuracy 0.234 - college_medicine\n",
145
+ "Average accuracy 0.240 - computer_science\n",
146
+ "Average accuracy 0.263 - computer_security\n",
147
+ "Average accuracy 0.252 - conceptual_physics\n",
148
+ "Average accuracy 0.252 - construction_project_management\n",
149
+ "Average accuracy 0.239 - economics\n",
150
+ "Average accuracy 0.258 - education\n",
151
+ "Average accuracy 0.250 - electrical_engineering\n",
152
+ "Average accuracy 0.282 - elementary_chinese\n",
153
+ "Average accuracy 0.242 - elementary_commonsense\n",
154
+ "Average accuracy 0.282 - elementary_information_and_technology\n",
155
+ "Average accuracy 0.283 - elementary_mathematics\n",
156
+ "Average accuracy 0.252 - ethnology\n",
157
+ "Average accuracy 0.252 - food_science\n",
158
+ "Average accuracy 0.239 - genetics\n",
159
+ "Average accuracy 0.242 - global_facts\n",
160
+ "Average accuracy 0.272 - high_school_biology\n",
161
+ "Average accuracy 0.235 - high_school_chemistry\n",
162
+ "Average accuracy 0.271 - high_school_geography\n",
163
+ "Average accuracy 0.250 - high_school_mathematics\n",
164
+ "Average accuracy 0.255 - high_school_physics\n",
165
+ "Average accuracy 0.252 - high_school_politics\n",
166
+ "Average accuracy 0.254 - human_sexuality\n",
167
+ "Average accuracy 0.249 - international_law\n",
168
+ "Average accuracy 0.250 - journalism\n",
169
+ "Average accuracy 0.253 - jurisprudence\n",
170
+ "Average accuracy 0.252 - legal_and_moral_basis\n",
171
+ "Average accuracy 0.252 - logical\n",
172
+ "Average accuracy 0.238 - machine_learning\n",
173
+ "Average accuracy 0.243 - management\n",
174
+ "Average accuracy 0.250 - marketing\n",
175
+ "Average accuracy 0.249 - marxist_theory\n",
176
+ "Average accuracy 0.250 - modern_chinese\n",
177
+ "Average accuracy 0.241 - nutrition\n",
178
+ "Average accuracy 0.257 - philosophy\n",
179
+ "Average accuracy 0.251 - professional_accounting\n",
180
+ "Average accuracy 0.251 - professional_law\n",
181
+ "Average accuracy 0.242 - professional_medicine\n",
182
+ "Average accuracy 0.246 - professional_psychology\n",
183
+ "Average accuracy 0.247 - public_relations\n",
184
+ "Average accuracy 0.252 - security_study\n",
185
+ "Average accuracy 0.252 - sociology\n",
186
+ "Average accuracy 0.248 - sports_science\n",
187
+ "Average accuracy 0.254 - traditional_chinese_medicine\n",
188
+ "Average accuracy 0.243 - virology\n",
189
+ "Average accuracy 0.242 - world_history\n",
190
+ "Average accuracy 0.256 - world_religions\n",
191
+ "STEM 25.16\n",
192
+ "Humanities 24.78\n",
193
+ "Social Science 25.42\n",
194
+ "Other 25.15\n",
195
+ "China specific 25.26\n",
196
+ "Overall 25.17\n"
197
+ ]
198
+ }
199
+ ],
200
+ "source": [
201
+ "from dataclasses import dataclass\n",
202
+ "@dataclass\n",
203
+ "class Args:\n",
204
+ " data_dir: str = './CMMLU/data'\n",
205
+ " save_dir: str = './result'\n",
206
+ " num_few_shot: int = 0\n",
207
+ " max_length: int = 512\n",
208
+ "\n",
209
+ "run_eval(model, tokenizer, eval, Args())"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": []
218
+ }
219
+ ],
220
+ "metadata": {
221
+ "kernelspec": {
222
+ "display_name": "py310",
223
+ "language": "python",
224
+ "name": "python3"
225
+ },
226
+ "language_info": {
227
+ "codemirror_mode": {
228
+ "name": "ipython",
229
+ "version": 3
230
+ },
231
+ "file_extension": ".py",
232
+ "mimetype": "text/x-python",
233
+ "name": "python",
234
+ "nbconvert_exporter": "python",
235
+ "pygments_lexer": "ipython3",
236
+ "version": "3.10.12"
237
+ }
238
+ },
239
+ "nbformat": 4,
240
+ "nbformat_minor": 2
241
+ }
finetune_examples/.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ data
2
+ model_save
3
+ logs
finetune_examples/info_extract/data_process.py ADDED
@@ -0,0 +1,146 @@
1
+ import ujson
2
+ import codecs
3
+ import re
4
+ from rich import progress
5
+ import numpy as np
6
+
7
+
8
+ def process_all_50_schemas(raw_schemas_file: str='./data/all_50_schemas', save_schemas_file: str=None) -> list[str]:
9
+ '''
10
+ get the relation list used in the prompt
11
+ '''
12
+ lines = []
13
+ with codecs.open(raw_schemas_file, 'r', encoding='utf-8') as f:
14
+ lines = f.readlines()
15
+
16
+ scheme_list = []
17
+ for line in lines:
18
+ item = ujson.loads(line)
19
+ scheme_list.append(
20
+ item['predicate']
21
+ )
22
+
23
+ scheme_list = list(set(scheme_list))
24
+
25
+ if save_schemas_file:
26
+ with codecs.open(save_schemas_file, 'w', encoding='utf-8') as f:
27
+ ujson.dump(f"{scheme_list}", f, indent=4, ensure_ascii=False)
28
+
29
+ return scheme_list
30
+
31
+ def process_spo_list(text: str, spo_list: list, repair_song: bool=False):
32
+ '''
33
+ 处理spo_list,处理成{subject: 'subject', subject_start: 0, subject_end:3, predicate: 'predicate', object: 'object', object_start: 5, object_end = 7}
34
+ '''
35
+ new_spo_list = []
36
+
37
+ # 找出所有用书名号隔开的名字
38
+ some_name = re.findall('《([^《》]*?)》', text)
39
+ some_name = [n.strip() for n in some_name]
40
+
41
+ # 歌曲和专辑
42
+ song = []
43
+ album = []
44
+ for spo in spo_list:
45
+
46
+ # 修正so的错误,删除前后的书名号
47
+ s = spo['subject'].strip('《》').strip().lower()
48
+ o = spo['object'].strip('《》').strip().lower()
49
+ p = spo['predicate']
50
+
51
+ # 如果s在找到的名字中,以正则找到的s为准,用in判等,
52
+ # 如text: '《造梦者---dreamer》',但是标注的s是'造梦者'
53
+ for name in some_name:
54
+ if s in name and text.count(s) == 1:
55
+ s = name
56
+
57
+ if repair_song:
58
+ if p == '所属专辑':
59
+ song.append(s)
60
+ album.append(o)
61
+
62
+ temp = dict()
63
+ temp['s'] = s
64
+ temp['p'] = spo['predicate']
65
+ temp['o'] = o
66
+
67
+
68
+ # 在text中找不到subject 或者 object,不要这条数据了
69
+ if text.find(s) == -1 or text.find(o) == -1:
70
+ continue
71
+
72
+ new_spo_list.append(temp)
73
+
74
+ if repair_song:
75
+ ret_spo_list = []
76
+ ps = ['歌手', '作词', '作曲']
77
+
78
+ for spo in new_spo_list:
79
+ s, p, o = spo['s'], spo['p'], spo['o']
80
+ if p in ps and s in album and s not in song:
81
+ continue
82
+ ret_spo_list.append(spo)
83
+
84
+ return ret_spo_list
85
+
86
+ return new_spo_list
87
+
88
+
89
+ def process_data(raw_data_file: str, train_file_name: str, dev_file_name: str, keep_max_length: int=512, repair_song: bool=True, dev_size: int=1000) -> None:
90
+ '''
91
+ 将原始的格式处理为prompt:resopnse的格式
92
+ '''
93
+ lines = []
94
+ with codecs.open(raw_data_file, 'r', encoding='utf-8') as f:
95
+ lines = f.readlines()
96
+ my_raw_data = []
97
+
98
+ schemas = process_all_50_schemas('./data/all_50_schemas')
99
+ schemas = f"[{','.join(schemas)}]"
100
+ for i, line in progress.track(enumerate(lines), total=len(lines)):
101
+
102
+ tmp = ujson.decode(line)
103
+ text = f"请抽取出给定句子中的所有三元组。给定句子:{tmp['text'].lower()}"
104
+
105
+ spo_list = process_spo_list(tmp['text'].lower(), tmp['spo_list'], repair_song=repair_song)
106
+ spo = f"{[(item['s'], item['p'], item['o']) for item in spo_list]}"
107
+ # 删除长度过长、没有找到实体信息的句子
108
+ if len(text) > keep_max_length or len(spo) > keep_max_length or len(spo_list) == 0:
109
+ continue
110
+
111
+ my_raw_data.append({
112
+ 'prompt': text,
113
+ 'response':spo.replace('\'','').replace(' ', ''),
114
+ })
115
+
116
+
117
+ dev_date = []
118
+ if dev_file_name is not None:
119
+ dev_index = np.random.choice(range(0, len(my_raw_data)), size=dev_size, replace=False)
120
+ dev_index = set(dev_index)
121
+ assert len(dev_index) == dev_size
122
+
123
+ train_data = [x for i, x in enumerate(my_raw_data) if i not in dev_index]
124
+ dev_date = [x for i, x in enumerate(my_raw_data) if i in dev_index]
125
+
126
+ with codecs.open(dev_file_name, 'w', encoding='utf-8') as f:
127
+ ujson.dump(dev_date, f, indent=4, ensure_ascii=False)
128
+
129
+ my_raw_data = train_data
130
+
131
+ print(f'length of train data {len(my_raw_data)}, length of eval data {len(dev_date)}')
132
+
133
+ with codecs.open(train_file_name, 'w', encoding='utf-8') as f:
134
+ ujson.dump(my_raw_data, f, indent=4, ensure_ascii=False)
135
+
136
+ if __name__ == '__main__':
137
+ raw_data_file = './data/train_data.json'
138
+ train_file = './data/my_train.json'
139
+ dev_file = './data/my_eval.json'
140
+
141
+ process_all_50_schemas('./data/all_50_schemas', './data/my_schemas.txt')
142
+
143
+ process_data(raw_data_file, train_file, dev_file, keep_max_length=512, dev_size=1000)
144
+
145
+ # 使用该数据集公开的dev_data作为测试集
146
+ process_data('./data/dev_data.json', train_file_name='./data/test.json', dev_file_name=None, keep_max_length=512, dev_size=1000)
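
For reference, a minimal sketch of what process_spo_list produces for one sample (the sentence and annotations are taken from the evaluation notebook below; the import assumes the script above is importable as data_process):

from data_process import process_spo_list  # assumes the file above is on PYTHONPATH

text = '《离开》是由张宇谱曲,演唱'.lower()
spo_list = [
    {'subject': '离开', 'predicate': '歌手', 'object': '张宇'},
    {'subject': '离开', 'predicate': '作曲', 'object': '张宇'},
]

cleaned = process_spo_list(text, spo_list, repair_song=True)
# -> [{'s': '离开', 'p': '歌手', 'o': '张宇'}, {'s': '离开', 'p': '作曲', 'o': '张宇'}]

# the same formatting step process_data applies to build the response string
response = f"{[(d['s'], d['p'], d['o']) for d in cleaned]}".replace('\'', '').replace(' ', '')
print(response)  # [(离开,歌手,张宇),(离开,作曲,张宇)]
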
finetune_examples/info_extract/finetune_IE_task.ipynb ADDED
@@ -0,0 +1,463 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# coding=utf-8\n",
+ "from typing import Dict\n",
+ "import time\n",
+ "import pandas as pd\n",
+ "\n",
+ "import torch\n",
+ "from datasets import Dataset, load_dataset\n",
+ "from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments\n",
+ "from transformers.generation.configuration_utils import GenerationConfig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys, os\n",
+ "root = os.path.realpath('.').replace('\\\\','/').split('/')[0: -2]\n",
+ "root = '/'.join(root)\n",
+ "if root not in sys.path:\n",
+ "    sys.path.append(root)\n",
+ "\n",
+ "from model.chat_model import TextToTextModel\n",
+ "from config import SFTconfig, InferConfig, T5ModelConfig\n",
+ "from utils.functions import get_T5_config\n",
+ "\n",
+ "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_dataset(file: str, split: str, encode_fn: callable, encode_args: dict, cache_dir: str='.cache') -> Dataset:\n",
+ "    \"\"\"\n",
+ "    Load a dataset\n",
+ "    \"\"\"\n",
+ "    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)\n",
+ "\n",
+ "    def merge_prompt_and_responses(sample: dict) -> Dict[str, str]:\n",
+ "        # add an EOS token marking the end of the sequence; used during generation\n",
+ "        prompt = encode_fn(f\"{sample['prompt']}[EOS]\", **encode_args)\n",
+ "        response = encode_fn(f\"{sample['response']}[EOS]\", **encode_args)\n",
+ "        return {\n",
+ "            'input_ids': prompt.input_ids,\n",
+ "            'labels': response.input_ids,\n",
+ "        }\n",
+ "\n",
+ "    dataset = dataset.map(merge_prompt_and_responses)\n",
+ "    return dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sft_train(config: SFTconfig) -> Seq2SeqTrainer:\n",
+ "\n",
+ "    # step 1: load the tokenizer\n",
+ "    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)\n",
+ "\n",
+ "    # step 2: load the pretrained model\n",
+ "    model = None\n",
+ "    if os.path.isdir(config.finetune_from_ckp_file):\n",
+ "        # a directory is loaded with from_pretrained\n",
+ "        model = TextToTextModel.from_pretrained(config.finetune_from_ckp_file)\n",
+ "    else:\n",
+ "        # load_state_dict\n",
+ "        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)\n",
+ "        model = TextToTextModel(t5_config)\n",
+ "        model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu')) # set cpu for no exception\n",
+ "\n",
+ "    # step 3: load the dataset\n",
+ "    encode_args = {\n",
+ "        'truncation': False,\n",
+ "        'padding': 'max_length',\n",
+ "    }\n",
+ "\n",
+ "    dataset = get_dataset(file=config.sft_train_file, encode_fn=tokenizer.encode_plus, encode_args=encode_args, split=\"train\")\n",
+ "\n",
+ "    # step 4: define the training arguments\n",
+ "    # T5 is a sequence-to-sequence model, so use Seq2SeqTrainingArguments, DataCollatorForSeq2Seq and Seq2SeqTrainer\n",
+ "    # (the official huggingface SFT tooling targets causal language models)\n",
+ "    generation_config = GenerationConfig()\n",
+ "    generation_config.remove_invalid_values = True\n",
+ "    generation_config.eos_token_id = tokenizer.eos_token_id\n",
+ "    generation_config.pad_token_id = tokenizer.pad_token_id\n",
+ "    generation_config.decoder_start_token_id = tokenizer.pad_token_id\n",
+ "    generation_config.max_new_tokens = 320\n",
+ "    generation_config.repetition_penalty = 1.5\n",
+ "    generation_config.num_beams = 1     # greedy search\n",
+ "    generation_config.do_sample = False # greedy search\n",
+ "\n",
+ "    training_args = Seq2SeqTrainingArguments(\n",
+ "        output_dir=config.output_dir,\n",
+ "        per_device_train_batch_size=config.batch_size,\n",
+ "        auto_find_batch_size=True,  # avoid OOM\n",
+ "        gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
+ "        learning_rate=config.learning_rate,\n",
+ "        logging_steps=config.logging_steps,\n",
+ "        num_train_epochs=config.num_train_epochs,\n",
+ "        optim=\"adafactor\",\n",
+ "        report_to='tensorboard',\n",
+ "        log_level='info',\n",
+ "        save_steps=config.save_steps,\n",
+ "        save_total_limit=3,\n",
+ "        fp16=config.fp16,\n",
+ "        logging_first_step=config.logging_first_step,\n",
+ "        warmup_steps=config.warmup_steps,\n",
+ "        seed=config.seed,\n",
+ "        generation_config=generation_config,\n",
+ "    )\n",
+ "\n",
+ "    # step 5: init a collator\n",
+ "    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)\n",
+ "\n",
+ "    # step 6: define the trainer\n",
+ "    trainer = Seq2SeqTrainer(\n",
+ "        model=model,\n",
+ "        args=training_args,\n",
+ "        train_dataset=dataset,\n",
+ "        eval_dataset=dataset,\n",
+ "        tokenizer=tokenizer,\n",
+ "        data_collator=collator,\n",
+ "    )\n",
+ "\n",
+ "    # step 7: train\n",
+ "    trainer.train(\n",
+ "        # resume_from_checkpoint=True\n",
+ "    )\n",
+ "\n",
+ "    loss_log = pd.DataFrame(trainer.state.log_history)\n",
+ "    log_dir = './logs'\n",
+ "    if not os.path.exists(log_dir):\n",
+ "        os.mkdir(log_dir)\n",
+ "    loss_log.to_csv(f\"{log_dir}/ie_task_finetune_log_{time.strftime('%Y%m%d-%H%M')}.csv\")\n",
+ "\n",
+ "    # step 8: save the model\n",
+ "    trainer.save_model(config.output_dir)\n",
+ "\n",
+ "    return trainer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config = SFTconfig()\n",
+ "config.finetune_from_ckp_file = InferConfig().model_dir\n",
+ "config.sft_train_file = './data/my_train.json'\n",
+ "config.output_dir = './model_save/ie_task'\n",
+ "config.max_seq_len = 512\n",
+ "config.batch_size = 16\n",
+ "config.gradient_accumulation_steps = 4\n",
+ "config.logging_steps = 20\n",
+ "config.learning_rate = 5e-5\n",
+ "config.num_train_epochs = 6\n",
+ "config.save_steps = 3000\n",
+ "config.warmup_steps = 1000\n",
+ "print(config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainer = sft_train(config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys, os\n",
+ "root = os.path.realpath('.').replace('\\\\','/').split('/')[0: -2]\n",
+ "root = '/'.join(root)\n",
+ "if root not in sys.path:\n",
+ "    sys.path.append(root)\n",
+ "import ujson, torch\n",
+ "from rich import progress\n",
+ "\n",
+ "from model.infer import ChatBot\n",
+ "from config import InferConfig\n",
+ "from utils.functions import f1_p_r_compute\n",
+ "inf_conf = InferConfig()\n",
+ "inf_conf.model_dir = './model_save/ie_task/'\n",
+ "bot = ChatBot(infer_config=inf_conf)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[(傅淑云,民族,汉族),(傅淑云,出生地,上海),(傅淑云,出生日期,1915年)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "ret = bot.chat('请抽取出给定句子中的所有三元组。给定句子:傅淑云,女,汉族,1915年出生,上海人')\n",
+ "print(ret)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[('傅淑云', '民族', '汉族'), ('傅淑云', '出生地', '上海'), ('傅淑云', '出生日期', '1915年')]\n"
+ ]
+ }
+ ],
+ "source": [
+ "def text_to_spo_list(sentence: str) -> list:\n",
+ "    '''\n",
+ "    Convert model output into a list of SPO tuples; time complexity O(n).\n",
+ "    '''\n",
+ "    spo_list = []\n",
+ "    sentence = sentence.replace(',', ',').replace('(', '(').replace(')', ')')  # normalize full-width punctuation\n",
+ "\n",
+ "    cur_txt, cur_spo, started = '', [], False\n",
+ "    for char in sentence:\n",
+ "        if char not in '[](),':\n",
+ "            cur_txt += char\n",
+ "        elif char == '(':\n",
+ "            started = True\n",
+ "            cur_txt, cur_spo = '', []\n",
+ "        elif char == ',' and started and len(cur_txt) > 0 and len(cur_spo) < 3:\n",
+ "            cur_spo.append(cur_txt)\n",
+ "            cur_txt = ''\n",
+ "        elif char == ')' and started and len(cur_txt) > 0 and len(cur_spo) == 2:\n",
+ "            cur_spo.append(cur_txt)\n",
+ "            spo_list.append(tuple(cur_spo))\n",
+ "            cur_spo = []\n",
+ "            cur_txt = ''\n",
+ "            started = False\n",
+ "    return spo_list\n",
+ "print(text_to_spo_list(ret))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_data = []\n",
+ "with open('./data/test.json', 'r', encoding='utf-8') as f:\n",
+ "    test_data = ujson.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'prompt': '请抽取出给定句子中的所有三元组。给定句子:查尔斯·阿兰基斯(charles aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部',\n",
+ "  'response': '[(查尔斯·阿兰基斯,出生地,圣地亚哥),(查尔斯·阿兰基斯,出生日期,1989年4月17日)]'},\n",
+ " {'prompt': '请抽取出给定句子中的所有三元组。给定句子:《离开》是由张宇谱曲,演唱',\n",
+ "  'response': '[(离开,歌手,张宇),(离开,作曲,张宇)]'}]"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_data[0:2]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bca40f71fcc34dda95eb97a6f48fea0c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Output()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
+ ],
+ "text/plain": []
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
+ "</pre>\n"
+ ],
+ "text/plain": [
+ "\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "prompt_buffer, batch_size, n = [], 32, len(test_data)\n",
+ "target_spo_list, predict_spo_list = [], []\n",
+ "for i, item in progress.track(enumerate(test_data), total=n):\n",
+ "    prompt_buffer.append(item['prompt'])\n",
+ "    target_spo_list.append(\n",
+ "        text_to_spo_list(item['response'])\n",
+ "    )\n",
+ "\n",
+ "    if len(prompt_buffer) == batch_size or i == n - 1:\n",
+ "        torch.cuda.empty_cache()\n",
+ "        model_pred = bot.chat(prompt_buffer)\n",
+ "        model_pred = [text_to_spo_list(item) for item in model_pred]\n",
+ "        predict_spo_list.extend(model_pred)\n",
+ "        prompt_buffer = []"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[[('查尔斯·阿兰基斯', '出生地', '圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')], [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]] \n",
+ "\n",
+ "\n",
+ " [[('查尔斯·阿兰基斯', '国籍', '智利'), ('查尔斯·阿兰基斯', '出生地', '智利圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')], [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(target_spo_list[0:2], '\\n\\n\\n', predict_spo_list[0:2])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "21636 21636\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(len(predict_spo_list), len(target_spo_list))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "f1: 0.74, precision: 0.75, recall: 0.73\n"
+ ]
+ }
+ ],
+ "source": [
+ "f1, p, r = f1_p_r_compute(predict_spo_list, target_spo_list)\n",
+ "print(f\"f1: {f1:.2f}, precision: {p:.2f}, recall: {r:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['你好,有什么我可以帮你的吗?',\n",
+ " '[(江苏省赣榆海洋经济开发区,成立日期,2003年1月28日)]',\n",
+ " '南方地区气候干燥,气候寒冷,冬季寒冷,夏季炎热,冬季寒冷的原因很多,可能是由于全球气候变暖导致的。\\n南方气候的变化可以引起天气的变化,例如气温下降、降雨增多、冷空气南下等。南方气候的变化可以促进气候的稳定,有利于经济发展和经济繁荣。\\n此外,南方地区的气候也可能受到自然灾害的影响,例如台风、台风、暴雨等,这些自然灾害会对南方气候产生影响。\\n总之,南方气候的变化是一个复杂的过程,需要综合考虑多方面因素,才能应对。']"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# quick test of general chat ability\n",
+ "bot.chat(['你好', '请抽取出给定句子中的所有三元组。给定句子:江苏省赣榆海洋经济开发区位于赣榆区青口镇临海而建,2003年1月28日,经江苏省人民政府《关于同意设立赣榆海洋经济开发区的批复》(苏政复〔2003〕14号)文件批准为全省首家省级海洋经济开发区,','如何看待最近南方天气突然变冷?'])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "py310",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
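
The evaluation above calls f1_p_r_compute from utils.functions, which is not part of this upload. Below is a minimal sketch of a micro-averaged, triple-level F1/precision/recall that is consistent with how the notebook calls it; the project's actual implementation may differ in detail:

from typing import List, Tuple

def f1_p_r_compute_sketch(pred: List[List[Tuple]], target: List[List[Tuple]]):
    # micro-average over all sentences: a predicted triple counts as correct
    # only if the identical (s, p, o) tuple appears in the gold set
    correct = pred_cnt = gold_cnt = 0
    for p_spo, t_spo in zip(pred, target):
        p_set, t_set = set(p_spo), set(t_spo)
        correct += len(p_set & t_set)
        pred_cnt += len(p_set)
        gold_cnt += len(t_set)
    precision = correct / pred_cnt if pred_cnt else 0.0
    recall = correct / gold_cnt if gold_cnt else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return f1, precision, recall
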
img/api_example.png ADDED
img/dpo_loss.png ADDED
img/ie_task_chat.png ADDED
img/sentence_length.png ADDED
img/sft_loss.png ADDED
img/show1.png ADDED
img/stream_chat.gif ADDED

Git LFS Details

  • SHA256: c9a4eb95e7afbae5f4940a9b83c942725a0d4cf1eb8390938ff8a7282300b910
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
img/train_loss.png ADDED
model/__pycache__/chat_model.cpython-310.pyc ADDED
Binary file (2.7 kB).
 
model/__pycache__/infer.cpython-310.pyc ADDED
Binary file (3.73 kB).
 
model/chat_model.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from torch import Tensor, LongTensor
+ from transformers import T5ForConditionalGeneration, T5Config
+ from transformers import TextIteratorStreamer
+ from transformers.generation.configuration_utils import GenerationConfig
+
+ class TextToTextModel(T5ForConditionalGeneration):
+     def __init__(self, config: T5Config) -> None:
+         '''
+         TextToTextModel inherits from T5ForConditionalGeneration
+         '''
+         super().__init__(config)
+
+     @torch.no_grad()
+     def my_generate(self,
+                 input_ids: LongTensor,
+                 attention_mask: LongTensor,
+                 max_seq_len: int=256,
+                 search_type: str='beam',
+                 streamer: TextIteratorStreamer=None,
+             ) -> Tensor:
+         '''
+         Custom generate wrapper for convenient calling and testing.
+         search_type: ['greedy', 'beam', 'sampling', 'contrastive', ]
+
+         - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+           `do_sample=False`
+         - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+           and `top_k>1`
+         - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+           `do_sample=True`
+         - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+           `do_sample=False`
+         - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+           `num_beams>1` and `do_sample=True`
+         '''
+         generation_config = GenerationConfig()
+         generation_config.remove_invalid_values = True
+         generation_config.eos_token_id = 1  # eos/pad ids are hard-coded to this project's tokenizer
+         generation_config.pad_token_id = 0
+         generation_config.decoder_start_token_id = self.config.decoder_start_token_id
+         generation_config.max_new_tokens = max_seq_len
+         # generation_config.repetition_penalty = 1.1  # penalty for repeated tokens
+
+         if search_type == 'greedy':
+             generation_config.num_beams = 1
+             generation_config.do_sample = False
+         elif search_type == 'beam':
+             generation_config.top_k = 50
+             generation_config.num_beams = 5
+             generation_config.do_sample = True
+             generation_config.top_p = 0.95
+             generation_config.no_repeat_ngram_size = 4
+             generation_config.length_penalty = -2.0
+             generation_config.early_stopping = True
+         elif search_type == 'sampling':
+             generation_config.num_beams = 1
+             generation_config.do_sample = True
+             generation_config.top_k = 50
+             generation_config.temperature = 0.98  # lower values sharpen the distribution; values > 1 flatten it towards uniform
+             generation_config.top_p = 0.80
+             generation_config.no_repeat_ngram_size = 4
+         elif search_type == 'contrastive':
+             generation_config.penalty_alpha = 0.5
+             generation_config.top_k = 50
+
+         result = self.generate(
+             inputs=input_ids,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             streamer=streamer,
+         )
+
+         return result
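
A minimal usage sketch for my_generate. The config class and model_dir follow the rest of this repo, but treat the exact checkpoint path as an assumption:

from transformers import PreTrainedTokenizerFast
from model.chat_model import TextToTextModel
from config import InferConfig

model_dir = InferConfig().model_dir  # assumed to point at a trained checkpoint directory
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
model = TextToTextModel.from_pretrained(model_dir).eval()

# prompts end with the project's [EOS] marker, mirroring training
encoded = tokenizer('你好[EOS]', return_tensors='pt')
out = model.my_generate(input_ids=encoded.input_ids,
                        attention_mask=encoded.attention_mask,
                        max_seq_len=64,
                        search_type='beam')  # or 'greedy', 'sampling', 'contrastive'
print(tokenizer.batch_decode(out, skip_special_tokens=True))
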
model/chat_model_config.py ADDED
@@ -0,0 +1,4 @@
+ from transformers import T5Config
+
+ class TextToTextModelConfig(T5Config):
+     model_type = 't5'
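
Keeping model_type = 't5' means configs saved from this subclass still resolve to the stock T5 mapping. A short sketch of that round trip (the sizes below are made-up; AutoConfig and save_pretrained are standard transformers APIs):

from transformers import AutoConfig
from model.chat_model_config import TextToTextModelConfig

cfg = TextToTextModelConfig(vocab_size=29298, d_model=768)  # hypothetical sizes
assert cfg.model_type == 't5'
cfg.save_pretrained('./tmp_cfg')
# because model_type stays 't5', the saved config loads back as a plain T5Config
print(type(AutoConfig.from_pretrained('./tmp_cfg')))
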
model/dataset.py ADDED
@@ -0,0 +1,290 @@
+ from typing import Union
+
+ from torch.utils.data import Dataset
+ from torch import LongTensor, cuda
+ from transformers import PreTrainedTokenizerFast
+ from torch.utils.data import DataLoader
+ from datasets import load_dataset
+ import datasets
+ import pyarrow.parquet as pq
+ from numpy import array, int64
+ from numpy.random import shuffle
+
+ # import sys
+ # sys.path.extend(['.', '..'])
+
+ from config import PROJECT_ROOT
+
+ class MyDataset(Dataset):
+
+     def __init__(self,
+                 parquet_file: str,
+                 tokenizer_dir: str,
+                 keep_in_memory: bool=False,
+                 max_seq_len: int=512,
+                 buffer_size: int=40960,
+             ) -> None:
+         '''
+         keep_in_memory: whether to load the parquet file into memory as a pandas.DataFrame.
+         False uses a generator instead (which cannot shuffle the data), reducing memory
+         usage for large datasets.
+         '''
+         super().__init__()
+
+         if cuda.device_count() >= 2 and not keep_in_memory:
+             raise ValueError(f'when using MyDataset with multiple GPUs, keep_in_memory must be True, otherwise distributed training is not possible. current keep_in_memory={keep_in_memory}')
+
+         self.keep_in_memory = keep_in_memory
+         self.max_seq_len = max_seq_len
+
+         # read with pyarrow.parquet; to_pandas and for-loops over it are faster
+         parquet_table = pq.read_table(parquet_file)
+
+         # dataset length
+         self.length = parquet_table.num_rows
+
+         # the buffer cannot be larger than the dataset
+         self.buffer_size = self.length if buffer_size > self.length else buffer_size
+
+         if keep_in_memory:
+             # convert to pandas and keep in memory
+             self.data = parquet_table.to_pandas()
+         else:
+             self.data = parquet_table
+
+         # init the tokenizer
+         self.tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
+
+         # init the sample generator here
+         self.sample_generator = self.item_generator()
+
+     def item_generator(self,) -> tuple:
+         '''
+         Generator yielding one sample at a time, to avoid OOM on large datasets.
+         '''
+         parquet_table = self.data
+
+         # the generator loops forever; training (end of epoch) simply stops calling next()
+         buffer_list = []
+         while True:
+
+             for prompt, response in zip(parquet_table['prompt'], parquet_table['response']):
+
+                 # buffer not full yet: keep adding samples
+                 if len(buffer_list) < self.buffer_size:
+                     buffer_list.append( (prompt.as_py(), response.as_py()) )
+                     continue
+
+                 # buffer is full: shuffle it and yield everything buffered
+                 shuffle(buffer_list)
+                 for p, r in buffer_list:
+                     # yield from here
+                     yield p, r
+
+                 # start a new buffer with the current sample, which has not been yielded yet
+                 buffer_list = [ (prompt.as_py(), response.as_py()) ]
+
+     def __getitem__(self, index):
+         '''
+         Return one sample.
+         '''
+         if self.keep_in_memory:
+             data = self.data
+             prompt, response = data.iloc[index].prompt, data.iloc[index].response
+         else:
+             prompt, response = next(self.sample_generator)
+
+         max_seq_len = self.max_seq_len - 5  # len('[EOS]') = 5
+         # add an EOS token marking the end of the response; used during generation
+         return f"{prompt[0: max_seq_len]}[EOS]", f"{response[0: max_seq_len]}[EOS]"
+
+     def collate_fn(self, data: list[list]) -> dict:
+         '''
+         Merge one batch of samples.
+         '''
+         tokenizer = self.tokenizer
+
+         prompt = tokenizer([item[0] for item in data], padding=True, return_token_type_ids=False)
+         response = tokenizer([item[1] for item in data], padding=True, return_token_type_ids=False)
+
+         input_ids = array(prompt.input_ids, dtype=int64)
+         input_mask = array(prompt.attention_mask, dtype=int64)
+         target_ids = array(response.input_ids, dtype=int64)
+
+         ret = {
+             'input_ids': LongTensor(input_ids),
+             'input_mask': LongTensor(input_mask),
+             'target_ids': LongTensor(target_ids),
+         }
+         return ret
+
+     def __len__(self) -> int:
+         return self.length
+
+ class ParquetDataset:
+
+     def __init__(self,
+                 parquet_file: Union[str, dict],
+                 tokenizer_dir: str,
+                 keep_in_memory: bool=False,
+                 cache_dir: str='./.cache',
+                 buffer_size: int=10240,
+                 max_len: int=512,
+                 seed: int=23333
+             ) -> None:
+         '''
+         Load with huggingface's load_dataset.
+         parquet_file: a single file, in which case only dataset['train'] is available;
+         for multiple files use: parquet_file={'train': 'train.parquet', 'test': 'test.parquet', 'validation': 'validation.parquet'}
+         see https://huggingface.co/docs/datasets/loading for other usages
+         keep_in_memory: whether to load the parquet file into memory as a pandas.DataFrame
+         '''
+         self.keep_in_memory = keep_in_memory
+         self.len_dict = self.__get_all_parquet_file_size(parquet_file=parquet_file)
+
+         self.max_len = max_len
+         self.tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_dir)
+
+         streaming = False if keep_in_memory else True
+         # streaming=True, otherwise large datasets OOM
+         dataset = load_dataset('parquet', data_files=parquet_file, cache_dir=cache_dir, streaming=streaming)
+
+         # this batch_size is not the training batch size; it is the batch size handed to process_batch_func for batched mapping
+         dataset = dataset.map(self.process_batch_func, batched=True, batch_size=buffer_size, \
+                             remove_columns=['prompt', 'response'], fn_kwargs={'max_len': max_len})
+
+         dataset = dataset.with_format(type="torch")
+
+         if keep_in_memory:
+             dataset = dataset.shuffle(seed=seed, keep_in_memory=keep_in_memory)
+         else:
+             # only the buffer can be shuffled, not the whole dataset, so set the buffer a bit larger
+             dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+
+         self.dataset = dataset
+
+     @staticmethod
+     def process_batch_func(item: dict, max_len: int=512) -> dict:
+         '''
+         Append the EOS token.
+         '''
+         max_len -= 5  # len('[EOS]') = 5
+         for i in range(len(item['prompt'])):
+             item['prompt'][i] = f"{item['prompt'][i][0: max_len]}[EOS]"
+         for i in range(len(item['response'])):
+             item['response'][i] = f"{item['response'][i][0: max_len]}[EOS]"
+
+         return {
+             'prompt': item['prompt'],
+             'response': item['response'],
+         }
+
+     def collate_fn(self, data: list[list]) -> dict:
+         '''
+         Merge one batch of samples.
+         '''
+         tokenizer = self.tokenizer
+         prompt = [item['prompt'] for item in data]
+         response = [item['response'] for item in data]
+
+         # pad per batch
+         prompt_encoded = tokenizer(prompt, padding=True, return_token_type_ids=False)
+         response_encoded = tokenizer(response, padding=True, return_token_type_ids=False)
+
+         input_ids = array(prompt_encoded.input_ids, dtype=int64)
+         input_mask = array(prompt_encoded.attention_mask, dtype=int64)
+         target_ids = array(response_encoded.input_ids, dtype=int64)
+
+         ret = {
+             'input_ids': LongTensor(input_ids),
+             'input_mask': LongTensor(input_mask),
+             'target_ids': LongTensor(target_ids),
+         }
+         return ret
+
+     def __getitem__(self, index: str) -> datasets.Dataset:
+         '''
+         Magic method enabling subscript access, e.g. dataset['train'], dataset['validation'], dataset['test'].
+         '''
+         return self.dataset[index]
+
+     def __get_all_parquet_file_size(self, parquet_file: Union[str, dict]) -> dict:
+         '''
+         Get the length of every parquet file.
+         '''
+         len_dict = dict()
+         if type(parquet_file) is str:
+             train_len = self.__get_size_of_parquet(parquet_file)
+             len_dict['train'] = train_len
+
+         if type(parquet_file) is dict:
+             for split_type, file in parquet_file.items():
+                 len_dict[split_type] = self.__get_size_of_parquet(file)
+
+         return len_dict
+
+     def __get_size_of_parquet(self, file_name: str) -> int:
+         '''
+         Get the number of rows of one parquet file.
+         '''
+         parquet_data = pq.read_table(file_name)
+
+         return parquet_data.num_rows
+
+     def __len__(self) -> int:
+         '''
+         Magic method: if there is only one split, return its size.
+         '''
+         if len(self.len_dict) == 1:
+             return self.len_dict['train']
+         else:
+             raise Exception("this dataset contains many splited datasets, use `get_dataset_size(split_name)` function to get length, e.g: get_dataset_size('train')")
+
+     def get_dataset_size(self, split_name: str) -> int:
+         '''
+         Get the size of one split; split_name can be: train, validation, test.
+         '''
+         return self.len_dict[split_name]
+
+     def get_tokenizer(self, ) -> PreTrainedTokenizerFast:
+         return self.tokenizer
+
+
+ if __name__ == '__main__':
+     parquet_file = PROJECT_ROOT + '/data/my_valid_dataset.parquet'
+     tokenizer_dir = PROJECT_ROOT + '/model_save/tokenizer'
+
+     # example 1:
+     dataset = MyDataset(parquet_file, tokenizer_dir, keep_in_memory=False, max_seq_len=128)
+     print('\nexample 1, dataset size: ', len(dataset))
+     dataloader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
+
+     for epoch in range(2):
+         print('epoch: {}'.format(epoch))
+         for step, batch in enumerate(dataloader):
+             x, x_mask, y = batch['input_ids'], batch['input_mask'], batch['target_ids']
+             print('step:{}'.format(step), x.shape, x_mask.shape, y.shape)
+             if step == 5:
+                 break
+
+     # example 2:
+     dataset = ParquetDataset(parquet_file, tokenizer_dir, keep_in_memory=True, max_len=32)
+     dataloader = DataLoader(dataset['train'], batch_size=32, collate_fn=dataset.collate_fn)
+     print('\nexample 2, dataset size: ', dataset.get_dataset_size('train'))
+
+     for epoch in range(2):
+         print('epoch: {}'.format(epoch))
+         for step, batch in enumerate(dataloader):
+             x, x_mask, y = batch['input_ids'], batch['input_mask'], batch['target_ids']
+             print('step:{}'.format(step), x.shape, x_mask.shape, y.shape)
+             if step == 5:
+                 break
+
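
Per the docstring above, ParquetDataset also accepts a dict of split files, which the __main__ examples do not show. A minimal sketch (the parquet file names are placeholders):

from torch.utils.data import DataLoader
from model.dataset import ParquetDataset
from config import PROJECT_ROOT

files = {
    'train': PROJECT_ROOT + '/data/my_train_dataset_3k.parquet',      # placeholder paths
    'validation': PROJECT_ROOT + '/data/my_valid_dataset_1k.parquet',
}
tokenizer_dir = PROJECT_ROOT + '/model_save/tokenizer'

dataset = ParquetDataset(files, tokenizer_dir, keep_in_memory=True, max_len=128)
train_loader = DataLoader(dataset['train'], batch_size=8, collate_fn=dataset.collate_fn)
print(dataset.get_dataset_size('train'), dataset.get_dataset_size('validation'))
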
model/infer.py ADDED
@@ -0,0 +1,121 @@
+ import os
+ from threading import Thread
+ from typing import Union
+ import torch
+
+ from transformers import TextIteratorStreamer, PreTrainedTokenizerFast
+ from safetensors.torch import load_model
+
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+ # import project classes and functions
+ from model.chat_model import TextToTextModel
+ from utils.functions import get_T5_config
+
+ from config import InferConfig, T5ModelConfig
+
+ class ChatBot:
+     def __init__(self, infer_config: InferConfig) -> None:
+
+         self.infer_config = infer_config
+         # init the tokenizer
+         tokenizer = PreTrainedTokenizerFast.from_pretrained(infer_config.model_dir)
+         self.tokenizer = tokenizer
+         self.encode = tokenizer.encode_plus
+         self.batch_decode = tokenizer.batch_decode
+         self.batch_encode_plus = tokenizer.batch_encode_plus
+
+         t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
+
+         try:
+             model = TextToTextModel(t5_config)
+
+             if os.path.isdir(infer_config.model_dir):
+                 # a directory is loaded with from_pretrained
+                 model = model.from_pretrained(infer_config.model_dir)
+             elif infer_config.model_dir.endswith('.safetensors'):
+                 # load safetensors
+                 load_model(model, infer_config.model_dir)
+             else:
+                 # load a torch checkpoint
+                 model.load_state_dict(torch.load(infer_config.model_dir))
+
+             self.model = model
+
+         except Exception as e:
+             print(str(e), 'transformers and pytorch load fail, try accelerate load function.')
+
+             empty_model = None
+             with init_empty_weights():
+                 empty_model = TextToTextModel(t5_config)
+
+             self.model = load_checkpoint_and_dispatch(
+                 model=empty_model,
+                 checkpoint=infer_config.model_dir,
+                 device_map='auto',
+                 dtype=torch.float16,
+             )
+
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.model.to(self.device)
+
+         self.streamer = TextIteratorStreamer(tokenizer=tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+     def stream_chat(self, input_txt: str) -> TextIteratorStreamer:
+         '''
+         Streaming chat: returns right after the generation thread is started; iterate the
+         streamer to receive the generated text. Only greedy search is supported.
+         '''
+         encoded = self.encode(input_txt + '[EOS]')
+
+         input_ids = torch.LongTensor([encoded.input_ids]).to(self.device)
+         attention_mask = torch.LongTensor([encoded.attention_mask]).to(self.device)
+
+         generation_kwargs = {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'max_seq_len': self.infer_config.max_seq_len,
+             'streamer': self.streamer,
+             'search_type': 'greedy',
+         }
+
+         thread = Thread(target=self.model.my_generate, kwargs=generation_kwargs)
+         thread.start()
+
+         return self.streamer
+
+     def chat(self, input_txt: Union[str, list[str]]) -> Union[str, list[str]]:
+         '''
+         Non-streaming generation; beam search, beam sample and other decoding methods can be used.
+         '''
+         if isinstance(input_txt, str):
+             input_txt = [input_txt]
+         elif not isinstance(input_txt, list):
+             raise Exception('input_txt must be a str or list[str]')
+
+         # add the EOS token
+         input_txts = [f"{txt}[EOS]" for txt in input_txt]
+         encoded = self.batch_encode_plus(input_txts, padding=True)
+         input_ids = torch.LongTensor(encoded.input_ids).to(self.device)
+         attention_mask = torch.LongTensor(encoded.attention_mask).to(self.device)
+
+         outputs = self.model.my_generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_seq_len=self.infer_config.max_seq_len,
+             search_type='greedy',
+         )
+
+         outputs = self.batch_decode(outputs.cpu().numpy(), clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+         note = "我是一个参数很少的AI模型🥺,知识库较少,无法直接回答您的问题,换个问题试试吧👋"
+         outputs = [item if len(item) != 0 else note for item in outputs]
+
+         return outputs[0] if len(outputs) == 1 else outputs
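
A minimal sketch of consuming the streamer returned by stream_chat: generation runs on a background thread, so iterating the streamer yields text fragments as they are produced (this mirrors the ChatBot construction used in the notebook above):

from model.infer import ChatBot
from config import InferConfig

bot = ChatBot(infer_config=InferConfig())

# stream_chat returns immediately; each iteration yields the next decoded fragment
for text_piece in bot.stream_chat('你好'):
    print(text_piece, end='', flush=True)
print()
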
model/trainer.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import signal
2
+ import sys
3
+ import os
4
+ import time
5
+ from typing import Union
6
+ import platform
7
+
8
+ from psutil import virtual_memory, cpu_count
9
+ import numpy as np
10
+ from torch.utils.data import DataLoader
11
+ import torch
12
+ from rich.progress import Progress, TextColumn, BarColumn, TimeElapsedColumn, TimeRemainingColumn
13
+ from transformers import PreTrainedTokenizerFast
14
+ from torch_optimizer import Adafactor
15
+
16
+ # import accelerate
17
+ from accelerate import Accelerator
18
+ from accelerate.utils import set_seed
19
+
20
+ # import 自定义类和函数
21
+ from model.chat_model import TextToTextModel
22
+ from utils.logger import Logger
23
+ from model.dataset import MyDataset
24
+ from config import TrainConfig, T5ModelConfig
25
+ from utils.functions import (
26
+ get_bleu4_score,
27
+ save_model_config,
28
+ get_free_space_of_disk,
29
+ my_average,
30
+ get_path_of_suffix_files,
31
+ get_T5_config,
32
+ )
33
+
34
+ class ChatTrainer:
35
+ def __init__(self, train_config: TrainConfig, model_config: T5ModelConfig, ) -> None:
36
+
37
+ self.train_config = train_config
38
+ self.model_config = model_config
39
+
40
+ # file_name=None会自动生成以当前日期命名的log文件名
41
+ self.logger = Logger('chat_trainer', std_out=True, save2file=True, file_name=None)
42
+
43
+ self.model = None
44
+ self.accelerator = None
45
+
46
+ signal.signal(signal.SIGINT, self.process_exit_handler)
47
+
48
+ self.is_win_platform = True if platform.system().lower() == 'windows' else False
49
+
50
+ torch.manual_seed(train_config.seed)
51
+ torch.cuda.manual_seed_all(train_config.seed)
52
+
53
+ def process_exit_handler(self, signal_received, frame) -> None:
54
+ '''
55
+ 进程退出时的操作,保存模型
56
+ '''
57
+ if self.accelerator and self.model:
58
+ ask = "you are pressed `ctrl+c`, do you want to save checkpoint? Yes (y) or No (n)"
59
+ self.accelerator.print(ask)
60
+ ins = input()
61
+
62
+ if ins.lower() in ('yes', 'y'):
63
+
64
+ suffix = 'exit_save_{}'.format(str(time.strftime('%Y%m%d%H%M%S', time.localtime())))
65
+
66
+ self.accelerator.wait_for_everyone()
67
+ self.accelerator.save_state(output_dir=self.train_config.train_state_dir)
68
+
69
+ self.accelerator.print('model ckeck point has been saved in {}'.format(self.train_config.train_state_dir))
70
+
71
+ sys.exit(0)
72
+ else:
73
+ print('process not in trainingg, exit.')
74
+ sys.exit(0)
75
+
76
+ def save_model(self, suffix: Union[str, int]) -> None:
77
+ '''保存模型到文件
78
+ 注意:save_model不能放到is_main_process里面
79
+ e.g:
80
+ >>> self.save_model(epoch) # 在这里使用
81
+ >>> if accelerator.is_main_process:
82
+ >>> do_somthing()
83
+ '''
84
+ if self.model and self.accelerator:
85
+
86
+ # 先wait_for_everyone,再保存
87
+ self.accelerator.wait_for_everyone()
88
+
89
+ if self.accelerator.is_main_process:
90
+ unwrap_model = self.accelerator.unwrap_model(self.model)
91
+ model_dict = self.accelerator.get_state_dict(unwrap_model)
92
+ torch.save(model_dict, self.train_config.model_file.format(suffix))
93
+
94
+
95
+ def delete_early_checkpoint(self, epoch: int, keep_latest_n: int=3,) -> None:
96
+ '''
97
+ 删除最早的模型,最保留最近keep_latest_n个模型文件
98
+ '''
99
+ model_save_path = self.train_config.model_file
100
+ model_save_path = model_save_path.replace('\\', '/') # 针对win的路径,将\替换为/
101
+ model_save_path = '/'.join(model_save_path.split('/')[0: -1]) # 删除末尾文件名后缀
102
+
103
+ model_files = get_path_of_suffix_files(model_save_path, suffix='.bin', with_create_time=True)
104
+
105
+ # 进程异常退出保存模型文件不在删除范围
106
+ train_save_model_fils = []
107
+ for item in model_files:
108
+ if 'exit_save' not in item[0]:
109
+
110
+ # 大于当前epoch的文件不不删除
111
+ f_epoch = int(item[0].split('.')[-2])
112
+ if epoch >= f_epoch:
113
+ print(epoch, f_epoch, item)
114
+ train_save_model_fils.append(item)
115
+
116
+ train_save_model_fils.sort(key=lambda x: x[1]) # 按照时间从小到大排序
117
+
118
+ if len(train_save_model_fils) <= keep_latest_n:
119
+ return
120
+
121
+ to_delete_files = train_save_model_fils[0: -keep_latest_n]
122
+ for item in to_delete_files:
123
+ os.remove(item[0])
124
+
125
+
126
+ def train(self, is_keep_training: bool=False, is_finetune: bool=False) -> None:
127
+ '''
128
+ is_keep_training: 是否从断点处加载状态继续训练
129
+ is_finetune: 是否微调,微调的话可能需要冻结部分参数
130
+ '''
131
+ log = self.logger
132
+ train_config = self.train_config
133
+ save_steps = self.train_config.save_steps
134
+ logging_steps = self.train_config.logging_steps
135
+
136
+ # 梯度累计的步数
137
+ accumulation_steps = train_config.gradient_accumulation_steps
138
+
139
+ set_seed(train_config.seed)
140
+
141
+ accelerator = Accelerator(
142
+ mixed_precision=train_config.mixed_precision, # 混合精度
143
+ gradient_accumulation_steps=accumulation_steps, # 梯度累积
144
+ project_dir=train_config.train_state_dir,
145
+ )
146
+
147
+ # 根据剩余内存大小决定是否完全加载数据集到内存中
148
+ unuse_mem = virtual_memory().available / (1024 ** 3) # 单位:GB
149
+ unuse_disk = get_free_space_of_disk('./')
150
+
151
+ # 剩余内存≥48GB将把数据集留在内存中,因为2个显卡+全全部装载900多万的训练数据到内存需要大概43GB的CPU内存
152
+ # 如果不放在内存中,将会使用迭代器生成数据,CPU 内存小于16GB也可以运行,但是不支持顺序打乱。
153
+ # 多GPU keep_in_memory必须=True,否则无法进行分布式训练
154
+ keep_in_memory = True if unuse_mem >= 48.0 or torch.cuda.device_count() >= 2 else False
155
+
156
+ if accelerator.is_main_process:
157
+ log.info('cpu memory available: {:.2f} GB, disk space available: {:.2f} GB, keep dataset in memory: {}.'\
158
+ .format(unuse_mem, unuse_disk, keep_in_memory), save_to_file=True)
159
+ log.info('operation: {}, keep training: {}, loading datasets ...'.format('finetune' if is_finetune else 'train', is_keep_training))
160
+
161
+ # args for dataloader
162
+ num_workers = 0
163
+ # if not self.is_win_platform:
164
+ # cpu_cnt = cpu_count(logical=False)
165
+ # gpu_cnt = torch.cuda.device_count()
166
+ # if cpu_cnt >= 8 * gpu_cnt:
167
+ # # num_workers = 4 x number of available GPUs
168
+ # num_workers = int(4 * gpu_cnt)
169
+ # else:
170
+ # num_workers = int(cpu_cnt // 2)
171
+
172
+ train_dataset = MyDataset(
173
+ parquet_file=train_config.train_file,
174
+ tokenizer_dir=train_config.tokenizer_dir,
175
+ keep_in_memory=keep_in_memory,
176
+ max_seq_len=train_config.max_seq_len,
177
+ )
178
+ valid_dataset = MyDataset(
179
+ parquet_file=train_config.validation_file,
180
+ tokenizer_dir=train_config.tokenizer_dir,
181
+ keep_in_memory=keep_in_memory,
182
+ max_seq_len=train_config.max_seq_len,
183
+ )
184
+
185
+ batch_size = train_config.batch_size_per_gpu
186
+
187
+ train_dataloader = DataLoader(
188
+ train_dataset,
189
+ batch_size=batch_size,
190
+ shuffle=True,
191
+ collate_fn=train_dataset.collate_fn,
192
+ pin_memory=False,
193
+ num_workers=num_workers, #设置>1会导致cpu内存缓慢增涨,最后OOM,后面再研究为什么,num_workers=4,一个epoch只减少30分钟
194
+ )
195
+ valid_dataloader = DataLoader(
196
+ valid_dataset,
197
+ batch_size=batch_size,
198
+ shuffle=False,
199
+ collate_fn=valid_dataset.collate_fn,
200
+ pin_memory=False,
201
+ num_workers=num_workers,
202
+ )
203
+
204
+ device = accelerator.device
205
+ log.info('using device: {} '.format(str(device)), save_to_file=True)
206
+
207
+
208
+ # T5: All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
209
+ tokenizer = train_dataset.tokenizer
210
+ decoder_start_token_id = tokenizer.pad_token_id
211
+
212
+ # for t5, set decoder_start_token_id = pad_token_id
213
+ t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=decoder_start_token_id, eos_token_id=tokenizer.eos_token_id)
214
+
215
+ model = TextToTextModel(t5_config)
216
+
217
+ # 微调加载的模型并冻结embedding和encoder
218
+ if is_finetune:
219
+ model.load_state_dict(torch.load(train_config.finetune_from_ckp_file))
220
+ # print(model)
221
+
222
+ layers_to_freeze = [model.shared, model.encoder]
223
+
224
+ for layer in layers_to_freeze:
225
+ for param in layer.parameters():
226
+ param.requires_grad = False
227
+
228
+ # 保存模型配置,方便修改配置后恢复
229
+ save_model_config(t5_config.to_diff_dict(), train_config.model_config_file)
230
+
231
+ # T5训练,论文推荐使用Adafactor
232
+ optimizer = Adafactor(params=model.parameters(), lr=train_config.learn_rate)
233
+
234
+
235
+ # 获取当前机器有多少个GPU,默认全部使用
236
+ num_gpus_used = accelerator.state.num_processes
237
+
238
+ # 单机多卡,每个step总共的batch_size = batch_size_per_gpu * num_gpus_used
239
+ # total_batch_size 初始化为batch_size_per_gpu真的只有CPU的情况
240
+ total_batch_size = train_config.batch_size_per_gpu
241
+ if num_gpus_used >= 1:
242
+ total_batch_size = num_gpus_used * train_config.batch_size_per_gpu
243
+
244
+ steps_per_epoch = int(np.ceil(len(train_dataset) // total_batch_size))
245
+ eval_steps = int(np.ceil(len(valid_dataset) // total_batch_size))
246
+
247
+ if accelerator.is_main_process:
248
+ log.info('train dataset size: {}, steps per epoch:{}; validation dataset size: {}, steps per validation: {}; datalodater num_workers: {}.'\
249
+ .format(len(train_dataset), steps_per_epoch, len(valid_dataset), eval_steps, num_workers), save_to_file=True)
250
+
251
+
252
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
253
+ optimizer=optimizer,
254
+ max_lr=train_config.div_factor * train_config.learn_rate,
255
+ epochs=train_config.epochs,
256
+ steps_per_epoch=int(np.ceil( len(train_dataset) / (batch_size * accumulation_steps) )), # 梯度累积相当于增大了batch_size
257
+ div_factor=train_config.div_factor,
258
+ cycle_momentum=False,
259
+ )
260
+
261
+ model, optimizer, lr_scheduler, train_dataloader, valid_dataloader = accelerator.prepare(
262
+ model,
263
+ optimizer,
264
+ lr_scheduler,
265
+ train_dataloader,
266
+ valid_dataloader,
267
+ )
268
+
269
+ if is_keep_training:
270
+ accelerator.load_state(input_dir=train_config.train_state_dir)
271
+ accelerator.register_for_checkpointing(lr_scheduler)
272
+
273
+ self.model = model
274
+ self.accelerator = accelerator
275
+
276
+ best_bleu4 = 0.0
277
+ best_epoch = 0
278
+ epoch_loss_list = []
279
+
280
+ # 添加进度条,只在主进程更新
281
+ if accelerator.is_main_process:
282
+ progress = Progress(TextColumn("[progress.description]{task.description}"),
283
+ BarColumn(),
284
+ TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
285
+ TimeRemainingColumn(),
286
+ TimeElapsedColumn(),
287
+ TextColumn("[bold blue]{task.fields[show_info]}"),
288
+ refresh_per_second=1, # 每1秒钟更新一次,不要频繁更新
289
+ )
290
+
291
+ epoch_progress = progress.add_task(description='epoch: ', show_info='', total=train_config.epochs)
292
+ steps_progress = progress.add_task(description='steps: ', show_info='', \
293
+ total=np.ceil(steps_per_epoch / logging_steps))
294
+ eval_progress = progress.add_task(description='evaluate: ', show_info='', total=eval_steps, visible=False)
295
+
296
+ self.progress = progress
297
+ self.eval_progress = eval_progress
298
+
299
+ progress.start()
300
+
301
+ # end if
302
+
303
+ for epoch in range(train_config.epochs):
304
+
305
+ if accelerator.is_main_process:
306
+ epoch_show_txt = 'epoch: {}/{}, avg_loss: {:.6f}, best_epoch: {}, best_bleu: {}'.format(
307
+ epoch, train_config.epochs, my_average(epoch_loss_list), best_epoch, best_bleu4
308
+ )
309
+ progress.update(epoch_progress, show_info=epoch_show_txt)
310
+ progress.reset(steps_progress)
311
+
312
+ epoch_loss_list = []
313
+ model.train()
314
+
315
+ # torch.cuda.empty_cache()
316
+
317
+ for step, batch_data in enumerate(train_dataloader):
318
+
319
+ input_ids, input_mask = batch_data['input_ids'], batch_data['input_mask']
320
+ target_ids = batch_data['target_ids']
321
+ # for t5 model, all labels set to `-100` are ignored (masked)
322
+ target_ids[target_ids == decoder_start_token_id] = -100
323
+
324
+ outputs = model(
325
+ input_ids=input_ids,
326
+ attention_mask=input_mask,
327
+ labels=target_ids,
328
+ )
329
+
330
+ loss = outputs.loss.mean() / accumulation_steps
331
+
332
+ # attention here! loss.backward()
333
+ accelerator.backward(loss)
334
+
335
+ # 梯度累计
336
+ if (step + 1) % accumulation_steps == 0:
337
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
338
+
339
+ optimizer.step()
340
+ lr_scheduler.step()
341
+ optimizer.zero_grad()
342
+
343
+ # 每隔save_steps步保存一次模型
344
+ if (step + 1) % save_steps == 0 or step == steps_per_epoch:
345
+ self.save_model('epoch_{}_latest'.format(epoch))
346
+ accelerator.save_state(output_dir=train_config.train_state_dir)
347
+
348
+ # ==================================以下记录loss到日志============================================
349
+ # 每n步更新一次,避免频繁的cpu-gpu数据复制
350
+ # 参考:https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#avoid-unnecessary-cpu-gpu-synchronization
351
+
352
+ if step % logging_steps == 0 or step == steps_per_epoch:
353
+
354
+ loss_cpu = loss.detach().item() * accumulation_steps
355
+ epoch_loss_list.append(loss_cpu)
356
+
357
+ info_txt = 'training loss: epoch:{}, step:{}, loss:{}, device:{}'.\
358
+ format(epoch, step, loss_cpu, str(accelerator.device))
359
+
360
+ log.info(info_txt, std_out=False, save_to_file=True) # 保存 loss 到文件
361
+
362
+ # 更新进度条
363
+ if accelerator.is_main_process:
364
+ step_show_txt = 'step: {}/{}, loss: {:.6f}'.format(step, steps_per_epoch, loss_cpu)
365
+ progress.advance(steps_progress, advance=1)
366
+ progress.update(steps_progress, show_info=step_show_txt)
367
+
368
+ # ==================================以上记录loss到日志============================================
369
+
370
+ # if step >= 20:break
371
+
372
+ # end for batch setps
373
+
374
+ model.eval()
375
+
376
+ cur_bleu4_score = self.evaluate(
377
+ model=model,
378
+ tokenizer=tokenizer,
379
+ valid_dataloader=valid_dataloader,
380
+ accelerator=accelerator,
381
+ eval_steps=eval_steps,
382
+ )
383
+
384
+ # save model
385
+ if cur_bleu4_score >= best_bleu4:
386
+
387
+ best_bleu4 = cur_bleu4_score
388
+ best_epoch = epoch
389
+ # 最多保存最近keep_latest_n_ckp个模型文件
390
+ # self.delete_early_checkpoint(epoch=epoch, keep_latest_n=train_config.keep_latest_n_ckp)
391
+ self.save_model('best')
392
+ accelerator.save_state(output_dir=train_config.train_state_dir)
393
+
394
+ # 每个epoch打印一下日志
395
+ if accelerator.is_main_process:
396
+
397
+ progress.advance(epoch_progress, advance=1)
398
+ info_txt = 'epoch log: epoch:{}, avg_loss:{}, cur_bleu4:{}, best_bleu4:{}, best_epoch:{}'.\
399
+ format(epoch, my_average(epoch_loss_list), cur_bleu4_score, best_bleu4, best_epoch)
400
+ # log.info(info_txt, std_out=True, save_to_file=True)
401
+ self.print_and_log(info_txt, accelerator)
402
+
403
+
404
+ def evaluate(self,
405
+ model: TextToTextModel,
406
+ tokenizer: PreTrainedTokenizerFast,
407
+ valid_dataloader: DataLoader,
408
+ accelerator: Accelerator,
409
+ eval_steps: int,
410
+ ) -> float:
411
+
412
+ '''
413
+ 评估,返回平均的bleu分数
414
+ '''
415
+ max_seq_len = self.train_config.max_seq_len
416
+ batch_decode = tokenizer.batch_decode
417
+ bleu4_scores = []
418
+
419
+ if accelerator.is_main_process:
420
+ self.progress.reset(self.eval_progress)
421
+ self.progress.update(self.eval_progress, visible=True)
422
+
423
+ with torch.no_grad():
424
+ for step, batch_data in enumerate(valid_dataloader):
425
+
426
+ if accelerator.is_main_process:
427
+ self.progress.advance(self.eval_progress, advance=1)
428
+ self.progress.update(self.eval_progress, show_info='step: {}/{}'.format(step, eval_steps))
429
+
430
+ input_ids, input_mask = batch_data['input_ids'], batch_data['input_mask']
431
+ target_ids = batch_data['target_ids']
432
+
433
+ outputs = accelerator.unwrap_model(model).my_generate(
434
+ input_ids=input_ids,
435
+ attention_mask=input_mask,
436
+ max_seq_len=max_seq_len,
437
+ )
438
+
439
+ # gather data from multi-gpus (used when in ddp mode)
440
+ outputs = accelerator.gather_for_metrics(outputs).detach().cpu().numpy()
441
+ target_ids = accelerator.gather_for_metrics(target_ids).detach().cpu().numpy()
442
+
443
+ outputs = batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
444
+ target_ids = batch_decode(target_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
445
+
446
+ # print(outputs, target_ids)
447
+
448
+ bleu4_scores = [get_bleu4_score(reference=target_ids[i], outputs=outputs[i]) for i in range(len(target_ids))]
449
+ bleu4_scores.extend(bleu4_scores)
450
+
451
+ # if step >= 5: break
452
+
453
+ avg_bleu4_score = my_average(bleu4_scores)
454
+ if accelerator.is_main_process:
455
+ self.progress.update(self.eval_progress, show_info='bleu4 score: {}'.format(avg_bleu4_score))
456
+ self.progress.update(self.eval_progress, visible=False)
457
+
458
+ return avg_bleu4_score
459
+
+     def test(self, best_epoch: int=0) -> float:
+         '''
+         Evaluate the checkpoint of the given epoch on the test set and return its average BLEU-4 score.
+         '''
+         import os
+
+         train_config = self.train_config
+         log = self.logger
+
+         # args for dataloader
+         num_workers = 0 if self.is_win_platform else 4
+
+         test_dataset = MyDataset(
+             parquet_file=train_config.train_file,  # NOTE: reads the train file; point this at the test split if the config provides one
+             tokenizer_dir=train_config.tokenizer_dir,
+             keep_in_memory=False if self.is_win_platform else True,
+             max_seq_len=train_config.max_seq_len,
+         )
+
+         test_dataloader = DataLoader(
+             test_dataset,
+             batch_size=train_config.batch_size_per_gpu,
+             shuffle=False,
+             collate_fn=test_dataset.collate_fn,
+             pin_memory=False,
+             num_workers=num_workers,
+         )
+
+         log.info('test dataset size: {}.'.format(len(test_dataset)), save_to_file=True)
+
+         set_seed(train_config.seed)
+         accelerator = Accelerator(mixed_precision=train_config.mixed_precision)
+         device = accelerator.device
+         log.info('using device: {}'.format(str(device)), save_to_file=True)
+
+         # number of GPUs used by the current run
+         num_gpus_used = accelerator.state.num_processes
+
+         # single machine, multiple GPUs: total batch size per step = batch_size_per_gpu * num_gpus_used;
+         # total_batch_size falls back to batch_size_per_gpu in the CPU-only case
+         total_batch_size = train_config.batch_size_per_gpu
+         if num_gpus_used >= 1:
+             total_batch_size = num_gpus_used * train_config.batch_size_per_gpu
+
+         # T5: All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
+         tokenizer = test_dataset.tokenizer
+
+         model_file = train_config.model_file.format(best_epoch)
+         if os.path.isdir(model_file):
+             # a directory means a saved pretrained model: use from_pretrained
+             model = TextToTextModel.from_pretrained(model_file)
+         else:
+             # otherwise treat it as a raw state dict: load_state_dict
+             t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
+             model = TextToTextModel(t5_config)
+             model.load_state_dict(torch.load(model_file, map_location='cpu'))  # map to cpu to avoid device mismatch errors
+
+         model, test_dataloader = accelerator.prepare(
+             model,
+             test_dataloader,
+         )
+
+         # true division, so the ceil takes effect (with floor division the ceil would be a no-op)
+         steps = int(np.ceil(len(test_dataset) / total_batch_size))
+
+         bleu4_scores = []
+         batch_decode = tokenizer.batch_decode
+         max_seq_len = self.train_config.max_seq_len
+         model.eval()
+
+         if accelerator.is_main_process:
+             progress = Progress(TextColumn("[progress.description]{task.description}"),
+                                 BarColumn(),
+                                 TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+                                 TimeRemainingColumn(),
+                                 TimeElapsedColumn(),
+                                 TextColumn("[bold blue]{task.fields[show_info]}"),
+                                 refresh_per_second=1.0,
+                                 )
+
+             steps_progress = progress.add_task(description='steps: ', show_info='', total=steps)
+             progress.start()
+
+         with torch.no_grad():
+             for step, batch_data in enumerate(test_dataloader):
+
+                 if accelerator.is_main_process:
+                     progress.advance(steps_progress, advance=1)
+                     progress.update(steps_progress, show_info='step: {}/{}'.format(step, steps))
+
+                 input_ids, input_mask = batch_data['input_ids'], batch_data['input_mask']
+                 target_ids = batch_data['target_ids']
+
+                 # s = time.time()
+                 outputs = accelerator.unwrap_model(model).my_generate(
+                     input_ids=input_ids,
+                     attention_mask=input_mask,
+                     max_seq_len=max_seq_len,
+                 )
+                 # accelerator.print('generate used: {}'.format(time.time() - s))
+
+                 # gather data from multiple GPUs (needed in DDP mode)
+                 outputs = accelerator.gather_for_metrics(outputs).cpu().numpy()
+                 target_ids = accelerator.gather_for_metrics(target_ids).cpu().numpy()
+
+                 outputs = batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+                 target_ids = batch_decode(target_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+                 # print('outputs: {}'.format(outputs[0:5]))
+                 # print('target_ids: {}'.format(target_ids[0:5]))
+
+                 # accumulate the per-sample scores of this batch; rebinding bleu4_scores
+                 # here would silently drop the scores of all earlier batches
+                 batch_scores = [get_bleu4_score(reference=target_ids[i], outputs=outputs[i]) for i in range(len(target_ids))]
+                 bleu4_scores.extend(batch_scores)
+
+                 # if step >= 10: break
+
+         avg_bleu4_score = my_average(bleu4_scores)
+         if accelerator.is_main_process:
+             progress.update(steps_progress, show_info='bleu4 score: {}'.format(avg_bleu4_score))
+             progress.stop()  # stop the live display so the process can exit cleanly
+
+         info_txt = 'test_dataset_size: {}, avg_bleu4_score: {}.'.format(len(test_dataset), avg_bleu4_score)
+         log.info(info_txt, save_to_file=True)
+
+         return avg_bleu4_score
+
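The Rich progress bar built in test() above follows a reusable pattern: extra keyword arguments to add_task become custom task fields that a TextColumn can render. A standalone sketch (the sleep loop is a stand-in for real work):

import time
from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, TimeElapsedColumn

progress = Progress(
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
    TimeRemainingColumn(),
    TimeElapsedColumn(),
    TextColumn("[bold blue]{task.fields[show_info]}"),  # custom field
)
task = progress.add_task(description='steps: ', show_info='', total=100)
progress.start()
for step in range(100):
    time.sleep(0.01)
    progress.advance(task, advance=1)
    progress.update(task, show_info='step: {}/100'.format(step))
progress.stop()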
+     def print_and_log(self, info: str, accelerator: Accelerator=None) -> None:
+         '''
+         Print via accelerator.print when an Accelerator is given; plain print from
+         multiple processes would interleave and garble the output.
+         '''
+         if not accelerator:
+             print(info)
+         else:
+             accelerator.print(info)
+         self.logger.info(info, std_out=False, save_to_file=True)
+
+ if __name__ == '__main__':
+
+     # trainer = ChatTrainer()
+     train_config = TrainConfig()
+     model_config = T5ModelConfig()
+
+     chat_trainer = ChatTrainer(train_config=train_config, model_config=model_config)
+
+     chat_trainer.train()
+     # chat_trainer.test(best_epoch=0)
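To run this trainer on several GPUs from the command line, the usual Hugging Face Accelerate entry point applies (the script path and flag values are assumptions about this repo's layout, not taken from it):

accelerate launch --multi_gpu --mixed_precision bf16 ./model/trainer.py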
model_save/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
model_save/README.md ADDED
The diff for this file is too large to render. See raw diff
 
model_save/config.json ADDED
@@ -0,0 +1,33 @@
+ {
+   "_name_or_path": "./model_save/dpo/",
+   "architectures": [
+     "TextToTextModel"
+   ],
+   "auto_map": {
+     "AutoModelForSeq2SeqLM": "modeling_chat_model.TextToTextModel"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 3072,
+   "d_kv": 64,
+   "d_model": 768,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": false,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5",
+   "num_decoder_layers": 10,
+   "num_heads": 12,
+   "num_layers": 10,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.36.2",
+   "use_cache": true,
+   "vocab_size": 29298
+ }
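Because the auto_map entry above points AutoModelForSeq2SeqLM at modeling_chat_model.TextToTextModel, the checkpoint can also be loaded generically; remote code must be trusted so the custom class gets imported (the local path is an assumption):

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained('./model_save', trust_remote_code=True)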
model_save/configuration_chat_model.py ADDED
@@ -0,0 +1,4 @@
+ from transformers import T5Config
+
+ class TextToTextModelConfig(T5Config):
+     model_type = 't5'
model_save/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.36.2"
+ }
model_save/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:054caeae92bcc13f0b6e7a12f86e75c8e18117279ecd89c4aa1f8ac74c95c02a
+ size 750794624
model_save/modeling_chat_model.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from torch import Tensor, LongTensor
+ from transformers import T5ForConditionalGeneration, T5Config
+ from transformers import TextIteratorStreamer
+ from transformers.generation.configuration_utils import GenerationConfig
+
+ class TextToTextModel(T5ForConditionalGeneration):
+     def __init__(self, config: T5Config) -> None:
+         '''
+         TextToTextModel subclasses T5ForConditionalGeneration.
+         '''
+         super().__init__(config)
+
+     @torch.no_grad()
+     def my_generate(self,
+                     input_ids: LongTensor,
+                     attention_mask: LongTensor,
+                     max_seq_len: int=256,
+                     search_type: str='beam',
+                     streamer: TextIteratorStreamer=None,
+                     ) -> Tensor:
+         '''
+         Custom generate wrapper for easier calling and testing.
+         search_type: ['greedy', 'beam', 'sampling', 'contrastive', ]
+
+         - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+           `do_sample=False`
+         - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+           and `top_k>1`
+         - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+           `do_sample=True`
+         - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+           `do_sample=False`
+         - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+           `num_beams>1` and `do_sample=True`
+         '''
+         generation_config = GenerationConfig()
+         generation_config.remove_invalid_values = True
+         generation_config.eos_token_id = 1
+         generation_config.pad_token_id = 0
+         generation_config.decoder_start_token_id = self.config.decoder_start_token_id
+         generation_config.max_new_tokens = max_seq_len
+         # generation_config.repetition_penalty = 1.1  # penalty for repeated tokens
+
+         if search_type == 'greedy':
+             generation_config.num_beams = 1
+             generation_config.do_sample = False
+         elif search_type == 'beam':
+             generation_config.top_k = 50
+             generation_config.num_beams = 5
+             generation_config.do_sample = True
+             generation_config.top_p = 0.95
+             generation_config.no_repeat_ngram_size = 4
+             generation_config.length_penalty = -2.0
+             generation_config.early_stopping = True
+         elif search_type == 'sampling':
+             generation_config.num_beams = 1
+             generation_config.do_sample = True
+             generation_config.top_k = 50
+             generation_config.temperature = 0.98  # higher temperature flattens the distribution toward uniform; lower makes it sharper
+             generation_config.top_p = 0.80
+             generation_config.no_repeat_ngram_size = 4
+         elif search_type == 'contrastive':
+             generation_config.penalty_alpha = 0.5
+             generation_config.top_k = 50
+
+         result = self.generate(
+             inputs=input_ids,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             streamer=streamer,
+         )
+
+         return result
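A minimal sketch of exercising my_generate on a saved checkpoint (the paths and the example prompt are assumptions; the import path follows pre_train.py below):

from transformers import PreTrainedTokenizerFast
from model.chat_model import TextToTextModel

tokenizer = PreTrainedTokenizerFast.from_pretrained('./model_save')
model = TextToTextModel.from_pretrained('./model_save')
model.eval()

encoded = tokenizer(['你好,请问你是谁?'], return_tensors='pt')
outputs = model.my_generate(
    input_ids=encoded['input_ids'],
    attention_mask=encoded['attention_mask'],
    max_seq_len=256,
    search_type='greedy',
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))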
model_save/put_model_files_here ADDED
File without changes
model_save/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "eos_token": "[EOS]",
+   "pad_token": "[PAD]",
+   "unk_token": "[UNK]"
+ }
model_save/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model_save/tokenizer_config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "[EOS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[BOS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "5": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "6": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "[EOS]",
+   "model_max_length": 1000000000000000019884624838656,
+   "pad_token": "[PAD]",
+   "tokenizer_class": "PreTrainedTokenizerFast",
+   "unk_token": "[UNK]"
+ }
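A quick check that the special-token wiring defined above loads as expected (the local path is an assumption; the expected ids follow added_tokens_decoder):

from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained('./model_save')
print(tokenizer.pad_token, tokenizer.pad_token_id)  # [PAD] 0
print(tokenizer.eos_token, tokenizer.eos_token_id)  # [EOS] 1
print(tokenizer.unk_token, tokenizer.unk_token_id)  # [UNK] 6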
pre_train.py ADDED
@@ -0,0 +1,136 @@
+ # coding=utf-8
+ import time
+ import os
+ import pandas as pd
+ from dataclasses import dataclass
+ import torch
+ from typing import Dict
+
+ from tqdm import tqdm
+ import numpy as np
+ from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
+
+ from transformers.generation.configuration_utils import GenerationConfig
+ from datasets import Dataset, load_dataset
+
+ from model.chat_model import TextToTextModel
+ from model.dataset import MyDataset
+ from config import TrainConfig, T5ModelConfig
+
+ from utils.functions import json_to_dataclass, get_T5_config, MyTrainerCallback
+
+ tqdm.pandas()
+
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+
+ def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
+     """
+     Load the dataset.
+     """
+     dataset = load_dataset('parquet', data_files=file, split=split, cache_dir=cache_dir)
+
+     def tokens_to_ids(samples: dict) -> Dict[str, str]:
+
+         eos_token_id = tokenizer.eos_token_id
+
+         batch_prompt = samples['prompt']
+         batch_response = samples['response']
+
+         encoded_prompt = tokenizer(batch_prompt, truncation=False, padding=False, return_attention_mask=False,)
+         encoded_response = tokenizer(batch_response, truncation=False, padding=False, return_attention_mask=False,)
+
+         # uint16 is enough since the vocab size is below 65535; every sample gets a trailing eos_token_id
+         input_ids = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_prompt["input_ids"]]
+         labels = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_response["input_ids"]]
+
+         return {
+             'input_ids': input_ids,
+             'labels': labels,
+         }
+
+     dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
+
+     return dataset
+
+ def pre_train(config: TrainConfig) -> None:
+
+     # step 1. load the tokenizer
+     tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)
+
+     # step 2. load the model config
+     t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
+
+     # step 3. initialize the model
+     model = TextToTextModel(t5_config)
+
+     # step 4. load my dataset
+     dataset = get_dataset(file=config.train_file, split='train', tokenizer=tokenizer)
+
+     # step 5. define the training arguments
+
+     # T5 is a sequence-to-sequence model, so use Seq2SeqTrainingArguments, DataCollatorForSeq2Seq and Seq2SeqTrainer;
+     # the SFT tooling on the Hugging Face site targets (causal) language models instead
+
+     generation_config = GenerationConfig()
+     generation_config.remove_invalid_values = True
+     generation_config.eos_token_id = tokenizer.eos_token_id
+     generation_config.pad_token_id = tokenizer.pad_token_id
+     generation_config.decoder_start_token_id = tokenizer.pad_token_id
+     generation_config.max_new_tokens = 320
+     generation_config.num_beams = 1      # greedy search
+     generation_config.do_sample = False  # greedy search
+
+     training_args = Seq2SeqTrainingArguments(
+         output_dir=config.output_dir,
+         per_device_train_batch_size=config.batch_size_per_gpu,
+         auto_find_batch_size=True,  # fall back to a smaller batch size to avoid OOM
+         gradient_accumulation_steps=config.gradient_accumulation_steps,
+         learning_rate=config.learn_rate,
+         logging_steps=config.logging_steps,
+         num_train_epochs=config.epochs,
+         optim="adafactor",
+         report_to='tensorboard',
+         log_level='info',
+         save_steps=config.save_steps,
+         save_total_limit=3,
+         fp16=True if config.mixed_precision == 'fp16' else False,
+         bf16=True if config.mixed_precision == 'bf16' else False,
+         logging_first_step=True,
+         warmup_steps=config.warmup_steps,
+         seed=config.seed,
+         generation_config=generation_config,
+     )
+
+     # step 6. init my collator
+     collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
+     empty_cuda_cache = MyTrainerCallback()
+
+     # step 7. define the trainer
+     trainer = Seq2SeqTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=dataset,
+         tokenizer=tokenizer,
+         data_collator=collator,
+         callbacks=[empty_cuda_cache],
+     )
+
+     # step 8. train
+     trainer.train(
+         # resume_from_checkpoint=True
+     )
+
+     # step 9. save the training log
+     loss_log = pd.DataFrame(trainer.state.log_history)
+     log_dir = './logs'
+     if not os.path.exists(log_dir):
+         os.mkdir(log_dir)
+     loss_log.to_csv(f"{log_dir}/pre_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")
+
+     # step 10. save the model
+     trainer.save_model(config.output_dir)
+
+
+ if __name__ == '__main__':
+     config = TrainConfig()
+     pre_train(config)
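Before a long pre-training run, the collator's behavior can be sanity-checked in isolation; DataCollatorForSeq2Seq pads inputs with pad_token_id and labels with -100, so padded label positions are ignored by the T5 loss. A sketch with made-up token ids (the tokenizer path is an assumption):

from transformers import DataCollatorForSeq2Seq, PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained('./model_save')
collator = DataCollatorForSeq2Seq(tokenizer, max_length=320)

features = [
    {'input_ids': [5, 6, 7, 1], 'labels': [8, 9, 1]},
    {'input_ids': [5, 1],       'labels': [8, 9, 10, 11, 1]},
]
batch = collator(features)
print(batch['input_ids'])  # right-padded with pad_token_id (0)
print(batch['labels'])     # shorter label rows padded with -100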
requirements.txt ADDED
@@ -0,0 +1,29 @@
+ accelerate==0.25.0
+ colorlog==6.8.0
+ datasets==2.15.0
+ datasketch==1.6.4
+ fastapi==0.109.1
+ fastparquet==2023.10.1
+ fire==0.5.0
+ jieba==0.42.1
+ matplotlib==3.8.2
+ modelscope==1.11.1
+ nltk==3.8.1
+ numpy==1.26.2
+ opencc_python_reimplemented==0.1.7
+ pandas==2.1.4
+ peft==0.6.2
+ psutil==5.9.6
+ pyarrow==14.0.1
+ pydantic==2.5.2
+ rich==13.7.0
+ safetensors==0.4.1
+ sentencepiece==0.1.99
+ tokenizers==0.15.0
+ torch==2.1.1
+ torch_optimizer==0.3.0
+ tqdm==4.66.1
+ transformers==4.36.0
+ trl==0.7.4
+ ujson==5.8.0
+ uvicorn==0.24.0.post1
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ from typing import Dict
3
+ import time
4
+ import os
5
+ import pandas as pd
6
+ import numpy as np
7
+ import torch
8
+ from datasets import Dataset, load_dataset
9
+ from peft import LoraConfig
10
+ from tqdm import tqdm
11
+ from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq,Seq2SeqTrainingArguments
12
+ from transformers.generation.configuration_utils import GenerationConfig
13
+
14
+ from model.chat_model import TextToTextModel
15
+ from config import SFTconfig, T5ModelConfig
16
+ from utils.functions import get_T5_config, MyTrainerCallback
17
+
18
+ tqdm.pandas()
19
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
20
+
21
+ def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
22
+ """
23
+ 加载数据集
24
+ """
25
+
26
+ # 加载json数据集,如果要加载parquet,更改为'parquet'即可
27
+ dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)
28
+
29
+ def tokens_to_ids(samples: dict) -> Dict[str, str]:
30
+
31
+ eos_token_id = tokenizer.eos_token_id
32
+
33
+ batch_prompt = samples['prompt']
34
+ batch_response = samples['response']
35
+
36
+ encoded_prompt = tokenizer(batch_prompt, truncation=False, padding=False, return_attention_mask=False)
37
+ encoded_response = tokenizer(batch_response, truncation=False, padding=False, return_attention_mask=False)
38
+
39
+ # vocab size 小于65535 可以用 uint16, 每个样本都要添加eos_token_id
40
+ input_ids = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_prompt["input_ids"]]
41
+ labels = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_response["input_ids"]]
42
+
43
+ return {
44
+ 'input_ids': input_ids,
45
+ 'labels': labels,
46
+ }
47
+
48
+ dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
49
+
50
+ return dataset
51
+
52
+ def sft_train(config: SFTconfig) -> None:
53
+
54
+ # step 1. 加载tokenizer
55
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)
56
+
57
+ # step 2. 加载预训练模型
58
+ model = None
59
+ if os.path.isdir(config.finetune_from_ckp_file):
60
+ # 传入文件夹则 from_pretrained
61
+ model = TextToTextModel.from_pretrained(config.finetune_from_ckp_file)
62
+ else:
63
+ # load_state_dict
64
+ t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
65
+ model = TextToTextModel(t5_config)
66
+ model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu')) # set cpu for no exception
67
+
68
+ # Step 4: Load the dataset
69
+ dataset = get_dataset(file=config.sft_train_file, split="train", tokenizer=tokenizer)
70
+
71
+ # Step 5: Define the training arguments
72
+ # T5属于sequence to sequence模型,故要使用Seq2SeqTrainingArguments、DataCollatorForSeq2Seq、Seq2SeqTrainer
73
+ # huggingface官网的sft工具适用于language model/LM模型
74
+ generation_config = GenerationConfig()
75
+ generation_config.remove_invalid_values = True
76
+ generation_config.eos_token_id = tokenizer.eos_token_id
77
+ generation_config.pad_token_id = tokenizer.pad_token_id
78
+ generation_config.decoder_start_token_id = tokenizer.pad_token_id
79
+ generation_config.max_new_tokens = 320
80
+ generation_config.repetition_penalty = 1.5
81
+ generation_config.num_beams = 1 # greedy search
82
+ generation_config.do_sample = False # greedy search
83
+
84
+ training_args = Seq2SeqTrainingArguments(
85
+ output_dir=config.output_dir,
86
+ per_device_train_batch_size=config.batch_size,
87
+ auto_find_batch_size=True, # 防止OOM
88
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
89
+ learning_rate=config.learning_rate,
90
+ logging_steps=config.logging_steps,
91
+ num_train_epochs=config.num_train_epochs,
92
+ optim="adafactor",
93
+ report_to='tensorboard',
94
+ log_level='info',
95
+ save_steps=config.save_steps,
96
+ save_total_limit=3,
97
+ fp16=config.fp16,
98
+ logging_first_step=config.logging_first_step,
99
+ warmup_steps=config.warmup_steps,
100
+ seed=config.seed,
101
+ generation_config=generation_config,
102
+ )
103
+
104
+ # step 6: init a collator
105
+ collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
106
+ empty_cuda_cahce = MyTrainerCallback()
107
+
108
+ # Step 7: Define the Trainer
109
+ trainer = Seq2SeqTrainer(
110
+ model=model,
111
+ args=training_args,
112
+ train_dataset=dataset,
113
+ tokenizer=tokenizer,
114
+ data_collator=collator,
115
+ callbacks=[empty_cuda_cahce]
116
+ )
117
+
118
+ # step 8: train
119
+ trainer.train(
120
+ # resume_from_checkpoint=True
121
+ )
122
+
123
+ loss_log = pd.DataFrame(trainer.state.log_history)
124
+ log_dir = './logs'
125
+ if not os.path.exists(log_dir):
126
+ os.mkdir(log_dir)
127
+ loss_log.to_csv(f"{log_dir}/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")
128
+
129
+ # Step 9: Save the model
130
+ trainer.save_model(config.output_dir)
131
+
132
+ if __name__ == '__main__':
133
+ config = SFTconfig()
134
+ sft_train(config)
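The SFT file is expected to provide prompt/response pairs; the two field names come from tokens_to_ids above, while the file path and record contents here are placeholders:

from datasets import load_dataset

ds = load_dataset('json', data_files='./data/sft_train.json', split='train')
print(ds[0])  # e.g. {'prompt': '...', 'response': '...'}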
train.ipynb ADDED
@@ -0,0 +1,82 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from accelerate import notebook_launcher\n",
+     "import torch\n",
+     "\n",
+     "from model.trainer import ChatTrainer\n",
+     "from config import TrainConfig, T5ModelConfig"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train_config = TrainConfig()\n",
+     "model_config = T5ModelConfig()\n",
+     "\n",
+     "print(train_config)\n",
+     "print(model_config)\n",
+     "\n",
+     "gpu_count = torch.cuda.device_count()\n",
+     "print('gpu device count: {}'.format(gpu_count))\n",
+     "\n",
+     "chat_trainer = ChatTrainer(train_config=train_config, model_config=model_config)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "train = chat_trainer.train\n",
+     "\n",
+     "# chat_trainer.train() args: is_keep_training: bool, is_finetune: bool\n",
+     "train_args = (False, False)\n",
+     "\n",
+     "# use notebook_launcher to start multi-GPU training\n",
+     "notebook_launcher(train, num_processes=gpu_count, args=train_args, mixed_precision=train_config.mixed_precision)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "test = chat_trainer.test\n",
+     "notebook_launcher(test, num_processes=gpu_count, mixed_precision=train_config.mixed_precision)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   },
+   "orig_nbformat": 4
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }