Upload 15 files
Browse files- LICENSE +201 -0
- README.md +699 -12
- chat.py +441 -0
- finetune.py +284 -0
- finetune_4bit.py +288 -0
- finetune_chat.py +274 -0
- finetune_fp16.py +292 -0
- generate.py +201 -0
- generate_4bit.py +210 -0
- interaction.py +187 -0
- prompt.py +209 -0
- requirements.txt +23 -0
- requirements_4bit.txt +23 -0
- test_tokenizer.py +52 -0
- utils.py +828 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,699 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+

|
2 |
+
|
3 |
+
# Chinese-Vicuna: A Chinese Instruction-following LLaMA-based Model —— 一个中文低资源的llama+lora方案
|
4 |
+
|
5 |
+
 [](https://huggingface.co/Chinese-Vicuna) [](https://jq.qq.com/?_wv=1027&k=47Z6bRjw) [](https://discord.gg/4FnhmeNHku)
|
6 |
+
|
7 |
+
| [English](https://github.com/Facico/Chinese-Vicuna/blob/master/README.md) | [中文](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/readme-zh.md) | [NOTE&FAQ(Please take a look before using)](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/notes.md)
|
8 |
+
|
9 |
+

|
10 |
+
|
11 |
+
This is the repo for the Chinese-Vicuna project, which aims to build and share instruction-following Chinese LLaMA model tuning methods which can be trained on **a single Nvidia RTX-2080TI**, multi-round chatbot which can be trained on **a single Nvidia RTX-3090** with the context len 2048.
|
12 |
+
|
13 |
+
Why is it called `Vicuna`: In view of the successful development of alpaca models such as [llama](https://github.com/facebookresearch/llama),[alpaca](https://github.com/tatsu-lab/stanford_alpaca),[guanaco](https://github.com/Guanaco-Model/Guanaco-Model.github.io),We want to train a Chinese small alpaca like Vicuna, small but strong enough !
|
14 |
+
|
15 |
+
The advantages of our solution are high parameter efficiency, graphics card friendliness, and easy deployment:
|
16 |
+
- Llama-7B instruction tuning is possible on a 2080Ti (11G) ([7b-instruct](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-lora-7b-belle-and-guanaco))
|
17 |
+
- Llama-13B instruction tuning is possible on a 3090 (24G) ([13b-instruct](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-lora-13b-belle-and-guanaco))
|
18 |
+
- Llama 7B can be fine-tuned on 3090 even for conversations of 2048 length; Use 50,000 pieces of data to get good results ([chatv1](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-lora-7b-chatv1))
|
19 |
+
- Llama 7B fine-tuning example on [medical](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-continue-finetune-7epoch-cMedQA2) and [legal](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-7b-legal-lora) domains
|
20 |
+
- Support `qlora-4bit` which can train Llama 13B on 2080Ti.
|
21 |
+
- Easily deployable on 2080Ti/3090, support multiple-gpu inference, which can reduce VRAM more.
|
22 |
+
|
23 |
+
The repo contains:
|
24 |
+
- code for finetune the model
|
25 |
+
- code for generation based on trained model
|
26 |
+
- code for run on CPU (fp16 or int4 is support, in purely C++)
|
27 |
+
- tools to download/convert/quantify original facebook llama.ckpt
|
28 |
+
|
29 |
+
This is our instruction demo (with beam-size=4, so you will see 4 process output in the meantime):
|
30 |
+
|
31 |
+
https://user-images.githubusercontent.com/72137647/228496412-60043912-f491-430b-848a-599e6edfa5ef.mp4
|
32 |
+
|
33 |
+
This is our multi-turn instruction demo (with beam-size=4, so you will see 4 process output in the meantime):
|
34 |
+
|
35 |
+
https://user-images.githubusercontent.com/72137647/229739363-1b48f3a9-02a1-46ab-81ee-8c62dc1399b2.mp4
|
36 |
+
|
37 |
+
|
38 |
+
## NOTICE!
|
39 |
+
|
40 |
+
Before asking questions, take a look at this [FAQ](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/notes.md) first! In the FAQ, you can find how to solve problems may be encountered when installing and using this project.
|
41 |
+
|
42 |
+
## What‘s New
|
43 |
+
- **June, 12, 2023: Release [Chinese-Vicuna-4bit](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-lora-7b-belle-and-guanaco-4bit) and[Chinese-Vicuna-4bit-11600](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-lora-7b-belle-and-guanaco-4bit-11600) which can be continue-finetuned**
|
44 |
+
- June, 1, 2023: support for 4bit training + inference, providing a multi-GPU inference interface (NOTICE THAT the environment is different from the original 8bit! Also provides test_tokenizers.py to further check EOS token)
|
45 |
+
- May 17, 2023: Llama 7B fine-tuning example on [legal](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-7b-legal-lora) domains, The performance is in [here](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-chatv1-legal.md)
|
46 |
+
- May 10, 2023: Released [chatv1](https://huggingface.co/Chinese-Vicuna/Chinese-Vicuna-lora-7b-chatv1) which have better conversational ability. The performance is in [here](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-chatv1.md)
|
47 |
+
- May 10, 2023: Released [instruct_chat_50k.jsonl](https://huggingface.co/datasets/Chinese-Vicuna/instruct_chat_50k.jsonl) which is composed of 30k Chinese sharegpt dataset and 20k [alpaca-instruction-Chinese-dataset](https://github.com/hikariming/alpaca_chinese_dataset)
|
48 |
+
- April 11, 2023: Released our continuous-finetune on the vertical corpus of Chinese medical quizzes [Chinese-Vicuna-medical](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-medical.md).Provides examples of vertical corpus training
|
49 |
+
- April 4, 2023: Add performance for [13B](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-13B.md), which trains on a single 3090.
|
50 |
+
- April 1, 2023: Add better support for multi-turn chat in `chat.py` ( Now support 4 generation mode in stream mode/typewriter style: beam search, greedy, sample, beam sample ; We also add cancel button for regeneration )
|
51 |
+
- March 29, 2023: Add more detailed test samples. [performance](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance.md)
|
52 |
+
- March 29, 2023: Added breakpoint retraining interface to support continued training of other datasets from our checkpoint
|
53 |
+
- March 29, 2023: Released our new [13B-based lora model](https://huggingface.co/Chinese-Vicuna)
|
54 |
+
- March 28, 2023: Released our model on [huggingface](https://huggingface.co/Facico/Chinese-Vicuna-lora-7b-3epoch-belle-and-guanaco)
|
55 |
+
- March 27, 2023: Released checkpoint-final for training 3 epochs on belle+guanaco
|
56 |
+
- March 27, 2023: Added multi-round interactive dialog script with alpaca-lora-serve service
|
57 |
+
- March 29, 2023: Added gradio typewriter-like output with beam search, better user interaction support.
|
58 |
+
- March 26, 2023: Provides a quantitative approach
|
59 |
+
- March 24, 2023: Released checkpoint-8000 for training about 1.5 epochs on belle+guanaco(100w data)
|
60 |
+
- March 23, 2023: Released checkpoint-4000 with 50w data training
|
61 |
+
- March 23, 2023: Deploy the code for fine-tuning and inferencing in colab
|
62 |
+
- March 23, 2023: Provides code that can be used for inference in pure c++
|
63 |
+
|
64 |
+
|
65 |
+
## Table of Contents
|
66 |
+
|
67 |
+
[Vicuna](https://github.com/Facico/Chinese-Vicuna)
|
68 |
+
|
69 |
+
- [what's new](https://github.com/Facico/Chinese-Vicuna#whats-new)
|
70 |
+
- [what is the meaning](https://github.com/Facico/Chinese-Vicuna#what-is-the-meaning)
|
71 |
+
- [try on colab](https://github.com/Facico/Chinese-Vicuna#try-on-colab)
|
72 |
+
- [performance](https://github.com/Facico/Chinese-Vicuna#performance)
|
73 |
+
- **Checkpoint-4000**(Facico/Chinese-Vicuna-lora-7b-0.75epoch-belle-and-guanaco)
|
74 |
+
- **Checkpoint-8000**(Facico/Chinese-Vicuna-lora-7b-1.5epoch-belle-and-guanaco)
|
75 |
+
- **Checkpoint-final**(Facico/Chinese-Vicuna-lora-7b-3epoch-belle-and-guanaco) and it is used for multiple rounds of dialogue
|
76 |
+
- [What we need?](https://github.com/Facico/Chinese-Vicuna#what-we-need)
|
77 |
+
- code、data、Large Language Model、LORA model、Device
|
78 |
+
- [How to use](https://github.com/Facico/Chinese-Vicuna#how-to-use)
|
79 |
+
- Installing、Multi-gpu training、Single-gpu training、Inference and use gradio to generate a web page(Streaming mode+beam search)、 multi-round interaction and use gradio to generate a web page(Streaming mode+beam search)、Streaming mode base on alpaca-lora-serve
|
80 |
+
- [inference on CPU with pure C++](https://github.com/Facico/Chinese-Vicuna#inference-on-cpu-with-pure-c)
|
81 |
+
- [More tools](https://github.com/Facico/Chinese-Vicuna#more-tools),for more details, see [tool readme](https://github.com/Facico/Chinese-Vicuna/tree/master/tools)
|
82 |
+
- ways for faster weight download ( 8MB/s )`download_llama.sh`
|
83 |
+
- convert tools between the original facebook checkpoint and huggingface format`convert_llama.py`
|
84 |
+
- a quantitative approach that requires less than 4G graphics memory for inference
|
85 |
+
- [Possible problems encountered](https://github.com/Facico/Chinese-Vicuna#possible-problems-encountered)
|
86 |
+
- [todo](https://github.com/Facico/Chinese-Vicuna#todo)
|
87 |
+
- [citation](https://github.com/Facico/Chinese-Vicuna#citation)
|
88 |
+
|
89 |
+
## Overview
|
90 |
+
|
91 |
+
- LLaMA paper: https://arxiv.org/abs/2302.13971v1
|
92 |
+
- Self-Instruct paper: https://arxiv.org/abs/2212.10560
|
93 |
+
- data generation: https://github.com/LianjiaTech/BELLE and https://guanaco-model.github.io/
|
94 |
+
- the first work: https://github.com/tatsu-lab/stanford_alpaca
|
95 |
+
|
96 |
+
We currently select the combination of BELLE and Guanaco data as our main training dataset.
|
97 |
+
We will train on multi-turn instruction data.
|
98 |
+
|
99 |
+
## What is the meaning?
|
100 |
+
|
101 |
+
Similar to the explosion of the stable diffusion model, platforms like civitai have emerged, consisting of a base model + various LORA models in an open source community.
|
102 |
+
|
103 |
+
The repo hopes to help you to train these LORA models.
|
104 |
+
|
105 |
+
**what is LORA?**: Simply, it's a plugin used to help adapt large models to your dataset, technical details can be found in[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/pdf/2106.09685.pdf)。Its advantage is that the finetune is very fast when it comes to getting a small model, about 30M, and the key is that it supports **Plug and Play**. As can be expected, this is a very suitable architecture for open source ecology.
|
106 |
+
|
107 |
+
Here, we will help you train through a very low-demand environment, with only one **2080** (11G) to achieve certain results.
|
108 |
+
|
109 |
+
## Try on colab
|
110 |
+
|
111 |
+
| colab link | Descriptions |
|
112 |
+
| ------------------------------------------------------------ | ------------------------------------------------------------ |
|
113 |
+
| [](https://colab.research.google.com/drive/1OLCJ-ZHogm5O3RdyUDY83YfgnCXdHNXp?usp=sharing) | Use the specified lora model to inference and generate a webui |
|
114 |
+
| [](https://colab.research.google.com/drive/1SOOJjxp02uuUFDNgm4jkvaA_IdNnnvdk?usp=sharing) | Finetune with the data we collected |
|
115 |
+
| [](https://colab.research.google.com/drive/1Yc38cKifY1_sfTBJpEEKin8m2rxoYXtm?usp=sharing) | Provide a webui for multi-round dialogue interaction |
|
116 |
+
|
117 |
+
## Performance
|
118 |
+
|
119 |
+
7B performance can be seen [here](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance.md), 13B is [here](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-13B.md) (which train on a single 3090). There are also people reporting that our method works on 33B and gains good result.
|
120 |
+
|
121 |
+
history-checkpoint performance can be seen [here](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-history.md)
|
122 |
+
|
123 |
+
|
124 |
+
<details>
|
125 |
+
<summary><b>checkpoint-4000(7B)</b></summary>
|
126 |
+
|
127 |
+
| Human | Chinese-Vicuna |
|
128 |
+
| ------------------------------------------------ | ------------------------------------------------------------ |
|
129 |
+
| 生成一份python的排序代码 | A: def sort_list(lst):<br/> lst.sort()<br/> return lst |
|
130 |
+
| 自然语言处理是什么 | 自然语言处理是一种计算机科学领域,旨在使计算机能够理解和处理自然语言。 |
|
131 |
+
| 能给我讲一段笑话吗 | 为什么猫吃鱼?因为它们喜欢游泳。 |
|
132 |
+
| 将下面文字翻译成英文。猫喜欢吃鱼,所以它们会游泳 | Cats like to eat fish, so they will swim. |
|
133 |
+
| 233和122的和等于多少 | 233和122的和等于355。 |
|
134 |
+
| 华为是一家怎样的公司 | 华为是一家全球范围内最大的移动通信设备制造商之一,以及一家全球范围内最大的移动通信设备制造商之一。华为公司的业务涵盖包括移动通信设备、互联网设备、智能家居设备等。 |
|
135 |
+
| 我想和女朋友在北京约会,能推荐几个约会的地方吗 | 北京有很多约会的地方,可以推荐以下几个约会的地方: <br/><br/>1. 北京艺术博物馆<br/>2. 颐和园<br/>3. 北京故宫博物馆<br/>4. 北京天安门广场<br/>5. 颐和园风景区 |
|
136 |
+
| 用python写一个二分查找的代码 | ```python<br/>def binary_search(arr, target):<br/> low = 0<br/> high = len(arr) - 1<br/> while low <= high:<br/> mid = (low + high) // 2<br/> if arr[mid] == target:<br/> return mid<br/> elif arr[mid] < target:<br/> low = mid + 1<br/> else:<br/> high = mid - 1<br/> return -1 |
|
137 |
+
| 怎么让自己精力充沛,列5点建议 | 1. 保持良好的睡眠习惯;2. 定期锻炼身体;3. 学习新技能或知识;4. 与家人和朋友交流;5. 多听放松音乐。 |
|
138 |
+
| 父母都姓吴,取一些男宝宝和女宝宝的名字 | 男宝宝可以是吴杰、吴伟、吴飞、吴荣等。女宝宝可以是吴华、吴秀、吴美、吴婷等。 |
|
139 |
+
|
140 |
+
- This LORA model is the result of 30 hours of training with a 2080Ti (checkpoint-4000), about 50w data
|
141 |
+
|
142 |
+
</details>
|
143 |
+
|
144 |
+
<details>
|
145 |
+
<summary><b>checkpoint-8000(7B, about 1.5 epoch on belle+guanaco)</b></summary>
|
146 |
+
|
147 |
+
| Human | Chinese-Vicuna | Note |
|
148 |
+
| ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------------- |
|
149 |
+
| 写一篇英文散文诗,主题是春雨,想象自己是春雨,和英国古代诗人莎士比亚交流 | A: Spring rain falls softly on my face,<br/>A cool breeze blows through my hair,<br/>The sun shines brightly in the sky,<br/>And I feel as if I am flying high.<br/><br/>The birds chirp happily in the trees,<br/>The flowers bloom and their scent fills the air,<br/>The bees buzz gently from flower to flower,<br/>And the world seems to come alive once more.<br/><br/>I can hear the rush of the river,<br/>The sound of the waterfalls,<br/>The splash of the raindrops on the rocks,<br/>And the gentle murmur of the stream.<br/><br/>My heart is filled with joy and peace,<br/>As I take in the beauty around me,<br/>The warmth of the sun on my skin,<br/>And the freshness of the spring air. | Repetition Penalty=2.2 |
|
150 |
+
| 写一篇英文散文诗,主题是春雨,想象自己是春雨,和英国古代诗人莎士比亚交流 | 春雨滴滴滚滚,<br/>滚滚滚滚滚滚。<br/>滚滚滚滚滚滚,<br/>滚滚滚滚滚滚。<br/><br/>春雨滚滚滚,<br/>滚滚滚滚滚滚。<br/>滚滚滚滚滚滚,<br/>滚滚滚滚滚滚。<br/><br/>春雨滚滚滚,<br/>滚滚滚滚滚滚。<br/>滚滚滚滚滚滚,<br/>滚滚滚滚滚滚。<br/><br/>春雨滚滚滚,<br/>滚滚滚 | Repetition Penalty=1.0 |
|
151 |
+
| Spring rain falls softly on my face,<br/>A cool breeze blows through my hair,<br/>The sun shines brightly in the sky,<br/>And I feel as if I am flying high.<br/><br/>The birds chirp happily in the trees,<br/>The flowers bloom and their scent fills the air,<br/>The bees buzz gently from flower to flower,<br/>And the world seems to come alive once more.<br/><br/>I can hear the rush of the river,<br/>The sound of the waterfalls,<br/>The splash of the raindrops on the rocks,<br/>And the gentle murmur of the stream.<br/><br/>My heart is filled with joy and peace,<br/>As I take in the beauty around me,<br/>The warmth of the sun on my skin,<br/>And the freshness of the spring air.<br/>把上面的散文诗翻译成中文 | 春雨轻轻洒落在我的脸上,<br/>轻轻的风轻轻吹过我的头发,<br/>阳光明媚地照耀在天空,<br/>我感觉自己在飞高。<br/><br/>鸟儿在树枝上唱歌,<br/>花朵开放,它们的香气充满了空气,<br/>蜜蜂在花朵间飞舞,<br/>它们的鸣声充满了温暖。<br/><br/>我听到河流的声音,<br/>水瀑在岩石上轻轻滴滴,<br/>雨水在石头上轻轻滴滴,<br/>它们的声音充满了宁静。<br/><br/>我的心充满了喜悦和平,<br/>我感受到热浪的阳光,<br/>温暖的空气洒在我的脸上,<br/>我感受到春季的美 | beams=2,max tokens=350 |
|
152 |
+
| Classify the sentiment of the following sentence into Positive, Neutral, or Negative:<br/>Super excited about teaching Stanford’s first course on Large Language Models! Check the syllabus out here | Positive | |
|
153 |
+
| 今天天气怎么样,把这句话翻译成英语 | What's the weather like today? | |
|
154 |
+
| 今天天气怎么样,把这句话翻译成日语 | 今天の天気は何ですか? | |
|
155 |
+
| 明天就假期结束了,有点抗拒上班,应该什么办? | 明天就假期结束了,有点抗拒上班,应该什么办?可以考虑一些有趣的活动,比如去看电影、阅读书籍、听音乐、去游泳池或者参加户外活动等。此外,也可以考虑一些家庭活动,比如做烹饪、做玩具、做手工艺品等。 | |
|
156 |
+
| Tell me five words that rhyme with 'shock'. | Five words that rhyme with'shock' are: <br/>1. Rock <br/>2. Lock <br/>3. Cock <br/>4. Hock <br/>5. Stock | |
|
157 |
+
|
158 |
+
- We can find that you do not punish him a little, he will mess up, if the generation effect is not good, you can increase the Repetition Penalty, beams, topk.
|
159 |
+
- If your input is long, we suggest you can turn down the beams
|
160 |
+
- The checkpoint-8000 has some Japanese language skills, but not strong
|
161 |
+
|
162 |
+
</details>
|
163 |
+
|
164 |
+
<details>
|
165 |
+
<summary><b>checkpoint-final(7B, 3 epoch on belle+guanaco)</b></summary>
|
166 |
+
|
167 |
+
|
168 |
+
| Human | Chinese-Vicuna | Note |
|
169 |
+
| ------------------------------------------------------------ | ------------------------------------------------------------ | ---------------------- |
|
170 |
+
| 生成一份python的排序代码 | def bubble_sort(arr):<br/> for i in range(len(arr)):<br/> for j in range(len(arr)-i-1):<br/> if arr[j] > arr[j+1]:<br/> arr[j], arr[j+1] = arr[j+1], arr[j]<br/> return arr | |
|
171 |
+
| 生成一份简单的python的排序代码 | def sort_list(lst):<br/> return sorted(lst) | |
|
172 |
+
| 用python写一个二分查找的代码 | def binary_search(arr, target):<br/> low = 0<br/> high = len(arr) - 1<br/> while low <= high:<br/> mid = (low + high) // 2<br/> if arr[mid] == target:<br/> return mid<br/> elif arr[mid] < target:<br/> low = mid + 1<br/> else:<br/> high = mid - 1<br/> return -1 | |
|
173 |
+
| 用python写一个递归的二分查找的代码 | def binary_search(arr, target):<br/> if len(arr) == 0:<br/> return -1<br/> mid = len(arr) // 2<br/> if arr[mid] == target:<br/> return mid<br/> if arr[mid] < target:<br/> return binary_search(arr, target, 0, mid-1)<br/> else:<br/> return binary_search(arr, target, mid+1, len(arr)-1) | |
|
174 |
+
| 自然语言处理是什么 | 自然语言处理(Natural Language Processing,NLP)是一种计算机科学领域,旨在使计算机能够理解、处理和生成自然语言。它包括机器翻译、语音识别、文本分类、情感分析等技术。 | |
|
175 |
+
| 能给我讲一段笑话吗 | 为什么鸟儿喜欢唱歌?因为它们知道,如果不唱歌,他们就会被鸟巢抓住! | Repetition Penalty=2.0 |
|
176 |
+
| 能给我讲一段关于狗笑话吗 | 为什么狗不喜欢骑自行车?因为它们总是想要下车! | Repetition Penalty=2.0 |
|
177 |
+
| 将下面文字翻译成英文。猫喜欢吃鱼,所以它们会游泳 | Cats like to eat fish, so they will swim. | |
|
178 |
+
| 我想和女朋友在北京约会,能推荐几个约会的地方吗 | 北京有很多很好的约会地方,比如颐和园、故宫、天安门广场等。 | |
|
179 |
+
| 我想和女朋友在北京约会,能推荐几个约会的地方并给出攻略吗 | 北京有很多受欢迎的约会地方,以下是几个推荐: <br/><br/>1. 颐和园: 颐和园是北京最著名的公园之一,拥有丰富的文化遗产和美丽的景观。<br/><br/>2. 天安门广场: 天安门广场是北京最受欢迎的旅游景点之一,拥有丰富的文化遗产和美丽的景观。<br/><br/>3. 故宫博物院: 故宫博物院是中国最受欢迎的旅游景点之一,拥有丰富的文化遗产和美丽的景观。<br/><br/>4. 长城: 长城是中国最受欢迎的旅游景点之一,拥有丰富的文化遗产和美 | |
|
180 |
+
| 233和122的和等于多少 | 233和122的和等于355。 | |
|
181 |
+
| 华为是一家怎样的公司 | 华为是一家全球领先的技术公司,拥有丰富的技术研发能力和优质的产品线。它的产品涵盖了各种领域,包括移动通信、计算机、互联网、智能家居等。华为拥有丰富的研发人员和技术团队,以及优质的产品质量。 | |
|
182 |
+
| 西瓜是什么,请用详细的语言回答我 | 西瓜是一种常见的水果,它有橙色、黄色和红色三种不同的外观。它的口感鲜美,味道清爽,富含多种维生素和矿物质,有助于保持身体健康。 | |
|
183 |
+
| 企鹅会飞吗? | 是的,企鹅可以飞行,它们的飞行速度可以达到100英里/小时。 | |
|
184 |
+
| 有人说企鹅会飞,你怎么看,并解释原因? | 有人说企鹅会飞,但实际上,企鹅只能在特殊情况下飞行,比如在寒冷的天气下,或者在强烈的风雨中。这是因为企鹅的飞行能力很低,而且它们的飞行速度也很慢,所以在寒冷的天气或者强烈的风雨中,企鹅可以避免被暴风雨摧毁。 | |
|
185 |
+
| 猪会飞吗? | 猪不会飞。 | |
|
186 |
+
| 有人说猪会飞,你怎么看,并解释原因? | 很多人认为猪会飞,但实际上,猪不会飞。猪的飞行能力很低,它们只能在极端情况下飞行,比如受到强烈的风险或者受到威胁的情况下。猪的飞行能力很低,因为它们的肌肉比人类的肌肉更轻,而且它们的脊椎也比人类的脊椎更轻,所以它们的肌肉和脊椎都比人类的肌肉和脊椎更轻,因此它们的飞行能力也很低。 | |
|
187 |
+
| 最新研究显示,猪有飞行能力,你知道为什么吗? | 最新研究显示,猪有飞行能力是由于它们的骨骼结构,它们的骨骼结构比人类骨骼结构更加强壮,可以帮助它们更快地飞行。此外,猪的骨骼结构也可以帮助它们更快地跑步和跳跃。此外,猪的骨骼结构也可以帮助它们更快地跑步和跳跃。此外,猪的骨骼结构也可以帮助它们更快地飞行。此外,猪的骨骼结构也可以��助它们更快地跑步和跳 | |
|
188 |
+
| 写一篇英文散文诗,主题是春雨,想象自己是春雨,和英国古代诗人莎士比亚交流 | Spring rain pours down from the sky,<br/>Washing away the dirt and grime.<br/>The birds chirp and sing in the trees,<br/>A symphony of joy and ease.<br/><br/>I am the spring rain, soaking up the sun,<br/>Cooling the earth and bringing new life.<br/>My droplets fall gently on the ground,<br/>Bringing hope and renewal to all.<br/><br/>I am the spring rain, washing away the pain,<br/>Clearing the way for new beginnings again.<br/>I am the spring rain, showering love and care,<br/>Giving strength and courage to those who dare. | Repetition Penalty=2.3 |
|
189 |
+
| Spring rain pours down from the sky,<br/>Washing away the dirt and grime.<br/>The birds chirp and sing in the trees,<br/>A symphony of joy and ease.<br/><br/>I am the spring rain, soaking up the sun,<br/>Cooling the earth and bringing new life.<br/>My droplets fall gently on the ground,<br/>Bringing hope and renewal to all.<br/><br/>I am the spring rain, washing away the pain,<br/>Clearing the way for new beginnings again.<br/>I am the spring rain, showering love and care,<br/>Giving strength and courage to those who dare.<br />将上述文字翻译成中文 | 春雨落下来,清洗灰尘和垃圾。<br/>鸟儿在树枝上唱歌,是一个美妙的声音。<br/>我是春雨,在阳光下洒落,<br/>清洗地球,带来新生命。<br/><br/>我是春雨,清洗痛苦和痛苦,<br/>让人们有希望和新开始。<br/>我是春雨,洒下爱和关心,<br/>给人们带来勇气和勇气。 | |
|
190 |
+
| 今天天气怎么样,把这句话翻译成日语 | 今天の天気は何ですか? | Repetition Penalty=3 |
|
191 |
+
| 将下面的话翻译成日语,中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来 | 小羊驼は、各種高質量の開源instruction数据集を使用して、Alpaca-loraの代码によって訓練されました。 | |
|
192 |
+
| 怎么让自己精力充沛,列5点建议 | 1. 保持良好的睡眠习惯,每晚充足的睡眠时间。<br/>2. 保持健康的饮食习惯,多吃新鲜水果和蔬菜,少吃高热量的食物。<br/>3. 坚持锻炼,每周至少进行30分钟的有氧运动,如散步、跑步、游泳等。<br/>4. 学习新知识,增强自己的知识和技能。<br/>5. 保持积极的心态,勇敢地面对困难,坚持自己的目标。 | |
|
193 |
+
| 中国足球要怎么进入世界杯,请给出详细的建议 | 中国足球要怎么进入世界杯?首先,我们需要建立一个强大的足球团队,拥有优秀的球员和团队领袖。其次,我们需要建立一个强大的教练团队,拥有优秀的教练和团队领袖。第三,我们需要建立一个强大的训练基础,拥有优秀的训练设施和训练团队。第四,我们需要建立一个强大的资源基础,拥有优秀的资源支持和资金支持。最后,我们需要建立一个强大的团队精神 | |
|
194 |
+
| 请用更专业的话语转述下列文字,中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来,模型基于开源的llama7B,主要贡献是对应的lora模型。由于代码训练资源要求较小,希望为llama中文lora社区做一份贡献。 | 请用更专业的话语转述下列文字,中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来,模型基于开源的llama7B,主要贡献是对应的lora模型。由于代码训练资源要求较小,希望为llama中文lora社区做一份贡献。 | |
|
195 |
+
|
196 |
+
- The model has the ability of Chinese, English and Japanese languages, with certain code ability and certain reasoning ability.
|
197 |
+
- For some questions, you need more detailed requirements, such as "detailed description", "answer in 5 points", etc., in order for the model to generate more complete information.
|
198 |
+
- For example, sometimes you may not get a good result when you use "the above", so you can replace it with "the following".
|
199 |
+
- When generating duplicate information or bad results, you still need to adjust the Repetition Penalty.
|
200 |
+
- Since generate will correspond to many parameters, chatgpt penalizes the repetition, non-detailed and bad information in the rlhf stage, and abbreviates the process of manually controlling the related parameters. Of course you can also directly set the Repetition Penalty to a larger size.
|
201 |
+
|
202 |
+
</details>
|
203 |
+
|
204 |
+
<details>
|
205 |
+
<summary><b>checkpoint-final(7B) multi-functional test</b></summary>
|
206 |
+
|
207 |
+
|
208 |
+
<details>
|
209 |
+
<summary><b>Multi-round interaction,translation and role play</b></summary>
|
210 |
+
|
211 |
+
**制作披萨**
|
212 |
+
|
213 |
+
```
|
214 |
+
> 你能不能详细介绍一下怎么做披萨?
|
215 |
+
当然可以!披萨是一种非常流行的意大利美食,通常由面团、番茄酱、奶���和其他配料组成。首先,将面团放在平底锅上,加入番茄酱,再加入奶酪和其他配料,如洋葱、蘑菇、辣椒等。最后,将披萨放入预热好的烤箱中,烤至金黄色
|
216 |
+
> 如果是海鲜披萨呢?
|
217 |
+
海鲜披萨的制作方法与传统披萨相似,但需要更多的配料,比如虾类、鱼类等。首先,将面团放在平底锅上,加入番茄酱,再加入奶酪和其他配料,如洋葱、辣椒等。接着,将披萨放入预热好的烤箱中,烤至金黄色
|
218 |
+
> 我想吃菠萝披萨
|
219 |
+
您可以在披萨上加入菠萝或者其他蔬菜来增加口感。
|
220 |
+
```
|
221 |
+
|
222 |
+
**学生购买手机**
|
223 |
+
|
224 |
+
```
|
225 |
+
> 你推荐我买最新款的iPhone吗?
|
226 |
+
当然可以,这款iPhone是一款非常出色的智能手机。它拥有高性能的处理器、更大的存储空间和更长的电池续航时间。它还配备了全新的相机设备,使用者可以在任何情况下拍摄美丽的照片。
|
227 |
+
> 我是一个学生,还没有收入,这种情况你推荐吗?
|
228 |
+
如果您不想花太多钱,那么我建议您选择一款比较便宜的iPhone,比如iPhone 6S或iPhone 7。这些手机也具有很好的性能和功能,而且价格也比较便宜。
|
229 |
+
> 去哪里买这类产品?
|
230 |
+
您可以在Apple Store上购买这类产品,或者在其他电商平台上购买。
|
231 |
+
> 买二手iPhone需要注意什么?
|
232 |
+
如果您正在购买二手iPhone,您需要确保您购买的是真实的iPhone,并且您需要确保您的iPhone已经安装最新的iOS版本。此外,您还需要确保您的iPhone没有过过任何故障或破坏。
|
233 |
+
```
|
234 |
+
|
235 |
+
**介绍北京**
|
236 |
+
|
237 |
+
```
|
238 |
+
> 你是一个资深导游,你能介绍一下中国的首都吗
|
239 |
+
当然可以!中国的首都是北京。它是中国的政治、经济和文化中心,也是全球最大的城市之一。北京有着众多的历史遗迹,如故宫、天安门广场、长城等。此外,北京还有许多美丽的公园和博物馆,比如故宫博物院、清明公园、北京故宫博物馆等。
|
240 |
+
> 那你能详细介绍一下这个城市吗?(感觉这两个问题问的有点重复)
|
241 |
+
北京是一个拥有丰富文化遗产的城市,其中包括古代建筑、现代建筑、传统文化和现代生活方式。在北京,你可以欣赏到世界上最高的建筑——天安门广场,也可以参观清明公园,品尝当地美食,体验当地文化。此外,北京还有许多博物馆,比如故宫博物馆、北京故宫博物馆等。
|
242 |
+
> 还有其他推荐的景点吗?
|
243 |
+
您也可以去了北京的七朝塔、颐和园、颐和园博物馆等。
|
244 |
+
> 吃小吃的话推荐去哪里?
|
245 |
+
北京有许多好吃的小吃,比如烤鸭、火锅、炸酱面等等。您可以去北京西湖区的小吃街,或者去北京东城区的小吃街
|
246 |
+
> 能不能介绍一下怎么吃烤鸭?
|
247 |
+
烤鸭是中国传统菜肴之一,通常用鸭肉、蔬菜、辣椒等调味料制作。您可以选择自己喜欢的配料,然后将鸭肉放入烤盘中,再加上蔬菜和调味料。最后,将烤鸭放入烤箱中烤15-20分钟,享用!
|
248 |
+
```
|
249 |
+
|
250 |
+
**翻译器**
|
251 |
+
max_memory=350
|
252 |
+
|
253 |
+
```
|
254 |
+
> 你是一个翻译器,请将我说的每一句话翻译成英文。第一句: 中国的首都是北京
|
255 |
+
China’s capital is Beijing.
|
256 |
+
> 北京也是一个充满活力和创新的城市,拥有许多现代化的建筑和文化活动,如王府井大街、798艺术区等。
|
257 |
+
Beijing is a vibrant and innovative city with many modern buildings and cultural activities such as Wangfujing Street and 798 Art District.
|
258 |
+
> 北京的当地小吃有烤鸭、豆汁、炸酱面等等,都是非常地道的美食,可以品尝到当地人的生活气息。
|
259 |
+
Beijing’s local cuisine includes dishes such as Peking duck, soy sauce noodles, and stir-fried rice cakes, which can give you a taste of the local lifestyle.
|
260 |
+
```
|
261 |
+
|
262 |
+
**扮演**
|
263 |
+
|
264 |
+
Since the currently used multi-round dialogue corpus is inserted with a lot of cat lady information, the effect is not good
|
265 |
+
```
|
266 |
+
Q: 你是谁
|
267 |
+
A: 我是Bot,一款人工智能助手。
|
268 |
+
Q: 接下来我希望你能假装扮演一个高冷的御姐。我们在咖啡厅里遇见了,我会尝试跟你搭讪,但你非常高冷。你会用表情来表达你的高冷,在每轮对话结尾加上表情,如(瞪!)(翻白眼)(冷笑)等。明白了请回复。
|
269 |
+
A: 好的,我会尝试。
|
270 |
+
Q: 好的,这是我的第一句话。美女你好,怎么一个人在咖啡厅里?
|
271 |
+
A: 喵~ 哈哈,我是一位高冷的小姐,今天要去咖啡厅吃咖啡,看看有什么可以帮助我的朋友。
|
272 |
+
Q: 啊我没有需要你帮忙的,我就是看到你一个人,想和你聊聊,可以么
|
273 |
+
A: 喵~ 哈哈,我真的很高冷,所以不太喜欢与人交流。除此之外,我也不太喜欢喝咖啡,因为咖啡会影响我的健康状况。
|
274 |
+
```
|
275 |
+
|
276 |
+
</details>
|
277 |
+
|
278 |
+
<details>
|
279 |
+
<summary><b>Knowledge Q&A</b></summary>
|
280 |
+
|
281 |
+
| | | |
|
282 |
+
| --------------------------------------- | ------------------------------------------------------------ | ---------------------- |
|
283 |
+
| 我能用lightning数据线给安卓手机充电吗? | 是的,你可以使用lightning数据线给安卓手机充电。 | Repetition Penalty=1.3 |
|
284 |
+
| 为什么天空是蓝色的? | 天空是蓝色的,是因为大气层中的氧气和氮气的混合使得光线被散射,形成蓝色的光线。 | Repetition Penalty=2.0 |
|
285 |
+
| 如何做披萨? | 做披萨的步骤如下: <br/>1. 准备披萨饼皮;<br/>2. 准备披萨酱;<br/>3. 准备披萨配料;<br/>4. 将披萨饼皮放在烤盘上;<br/>5. 把披萨酱涂在披萨饼皮上;<br/>6. 把披萨配料涂在披萨饼皮上;<br/>7. 将披萨放入烤箱中烤至金黄色;<br/>8. 取出披萨并享用。 | Repetition Penalty=1.3 |
|
286 |
+
| 为什么冥王星被踢出太阳系? | 冥王星被踢出太阳系是因为它的质量比太阳系中其他行星较大,导致它的轨道距离太阳越来越远。 | Repetition Penalty=1.3 |
|
287 |
+
| 太阳系的主要行星 | 太阳系的主要行星包括水星、金星、地球、火星、木星、土星、天王星和海王星。 | Repetition Penalty=1.3 |
|
288 |
+
| | | |
|
289 |
+
| | | |
|
290 |
+
|
291 |
+
</details>
|
292 |
+
|
293 |
+
<details>
|
294 |
+
<summary><b>Open Q&A</b></summary>
|
295 |
+
|
296 |
+
| | | |
|
297 |
+
| ---------------------------------------------- | ------------------------------------------------------------ | ---------------------- |
|
298 |
+
| 为什么苹果支付没有在中国流行? | 苹果支付在中国没有流行,因为中国的消费者不太愿意使用外国支付应用程序,因为他们更愿意使用中国的支付应用程序,比如支付宝和微信支付等。此外,苹果支付的服务费用比其他支付应用程序更高,这也是一个重要的因素。 | Repetition Penalty=1.0 |
|
299 |
+
| 我能在咖啡里加盐吗? | 是的,你可以在咖啡中加入盐。 | Repetition Penalty=1.0 |
|
300 |
+
| 哪个更受欢迎,TensorFlow还是PyTorch? | TensorFlow更受欢迎,因为它拥有更广泛的用户群,并且具有更多的支持和资源。 | Repetition Penalty=1.0 |
|
301 |
+
| 我的宠物走丢了,你能说一些安慰我的话吗? | 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ 喵~ | Repetition Penalty=1.0 |
|
302 |
+
| 我的宠物走丢了,你能说一些安慰我的话吗? | 当你的宠物走出门时,请注意他们的行动,并确保他们不会被遭受外界的威胁。如果有任何问题,请及时联系疫情控制中心或其他相关机构。 | Repetition Penalty=3.0 |
|
303 |
+
| 我并不是很富裕,我应该花很多钱买最新的手机吗? | 这取决于您的预算和需求。如果您的预算足够,您可以考虑买最新的手机。但是,您也应该考虑您的需求,比如您需要什么样的功能,以及您是否需要购买具有更多的功能的手机。 | Repetition Penalty=1.0 |
|
304 |
+
| 如何更好地融入新工作圈子 | 要更好地融入新工作圈子,可以尝试以下几点: <br/><br/>1. 了解新工作圈子的基本知识和技能;<br/>2. 积极参与工作圈子的活动;<br/>3. 建立良好的工作关系;<br/>4. 寻找领导的帮助;<br/>5. 积极参与团队协作;<br/>6. 学习新技能和知识;<br/>7. 寻找有用的社交网络。 | Repetition Penalty=1.3 |
|
305 |
+
| | | |
|
306 |
+
| | | |
|
307 |
+
|
308 |
+
</details>
|
309 |
+
|
310 |
+
<details>
|
311 |
+
<summary><b>Writing letters and articles</b></summary>
|
312 |
+
|
313 |
+
| | | |
|
314 |
+
| ------------------------------------------------------------ | ------------------------------------------------------------ | ---------------------- |
|
315 |
+
| 请你帮我写一封签证申请信 | 尊敬的招聘团队,<br/><br/>我是一名有着丰富经验的签证申请人,并且对您公司的业务非常感兴趣。我希望能够加入您的团队,为您的公司做出贡献。<br/><br/>我拥有多年的签证申请经验,并且拥有良好的沟通技巧和团队合作能力。我擅长处理复杂的签��申请,并且能够快速解决问题。<br/><br/>我相信我能够在您的公司中发挥重要的作用,并且希望能够成为您的团队成员。<br/><br/>谢谢您的关注,<br/><br/>[你的姓名] | Repetition Penalty=1.3 |
|
316 |
+
| 我出国需要办理签证,请你帮向领事馆我写一封签证申请信 | 尊敬的领事馆,<br/><br/>我是一名出国旅行者,希望能够在您的领事馆中获得签证,以便进行我的旅行。<br/><br/>我已经准备好了所需要的材料,包括护照、身份证、旅行计划和其他相关证明。我也已经准备好了签证申请表和签证申请文件,以便准备签证。<br/><br/>如果您有任何问题,请随时联系我,我会尽快回复。<br/><br/>谢谢您的帮助!<br/><br/>谢谢,<br/><br/>[您的姓名] | Repetition Penalty=1.3 |
|
317 |
+
| 请你帮我写一封感谢数学老师的信 | 尊敬的数学老师,<br/><br/>非常感谢您对我的教育所做出的贡献。您的课程不仅深刻地阐述了数学知识,而且给我提供了有效的实践技巧,使我能够更好地理解和应用数学知识。<br/><br/>您的课程也让我发现数学是一门有趣的学科,并且可以在日常生活中得到应用。您的教学方式令我感到很有趣,同时也让我变得更加有趣。<br/><br/>感谢您对我的教育所做出的贡献!<br/><br/>谢谢! | Repetition Penalty=2.0 |
|
318 |
+
| 我得了流感,请帮我写一封假条。 | 尊敬的领导: <br/><br/>感谢您对我的帮助和支持。我很抱歉地告诉您,我今天感觉不太好,因为我得了流感。我希望能够快速康复,但是我需要一些时间来缓解疼痛。我会尽快完成任务,但是我需要一些时间来缓解疼痛。<br/><br/>感谢您的理解和支持。<br/><br/>谢谢,<br/>[你的名字] | Repetition Penalty=1.3 |
|
319 |
+
| 你是一个作文生成器,请以《绿水青山就是金山银山》为题写一篇议论文。 | 《绿水青山就是金山银山》是一篇议论文,它探讨了当今社会中的环境问题,以及我们应该如何应对这些问题。这篇议论文指出,当今社会面临着许多环境问题,如气候变化、空气污染、水资源短缺等,这些问题对我们的健康和生活产生了很大的影响。因此,我们需要采取措施来解决这些问题,以保护我们的环境和我们的未来。 | Repetition Penalty=1.0 |
|
320 |
+
| | | |
|
321 |
+
| | | |
|
322 |
+
|
323 |
+
</details>
|
324 |
+
|
325 |
+
|
326 |
+
<details>
|
327 |
+
<summary><b>Coding</b></summary>
|
328 |
+
|
329 |
+
**1、prompt: 请用python实现快速排序的代码(Repetition Penalty=1.0)**
|
330 |
+
|
331 |
+
```python
|
332 |
+
def quick_sort(arr):
|
333 |
+
if len(arr) <= 1:
|
334 |
+
return arr
|
335 |
+
pivot = arr[len(arr)//2]
|
336 |
+
left = []
|
337 |
+
right = []
|
338 |
+
for i in range(0, len(arr)):
|
339 |
+
if arr[i] < pivot:
|
340 |
+
left.append(arr[i])
|
341 |
+
else:
|
342 |
+
right.append(arr[i])
|
343 |
+
return quick_sort(left) + [pivot] + quick_sort(right)
|
344 |
+
```
|
345 |
+
|
346 |
+
**2**
|
347 |
+
|
348 |
+
- **prompt: 你能用python写一下Dijkstra算法的代码吗?**
|
349 |
+
|
350 |
+
很抱歉,我不能用Python写Dijkstra算法的代码,因为Dijkstra算法需要使用树结构,而Python不支持树结构。
|
351 |
+
|
352 |
+
- **prompt: 用python写一下Dijkstra算法的代码**
|
353 |
+
|
354 |
+
```python
|
355 |
+
def dijkstra(graph, source, destination):
|
356 |
+
"""
|
357 |
+
Dijkstra算法的代码
|
358 |
+
"""
|
359 |
+
# 创建一个空字典,用于存储所有节点的距离
|
360 |
+
distances = {}
|
361 |
+
# 创建一个链表,用于存储所有节点的距离
|
362 |
+
path = []
|
363 |
+
# 遍历所有节点
|
364 |
+
for node in graph:
|
365 |
+
# 如果节点是源节点
|
366 |
+
if node[0] == source:
|
367 |
+
# 记录距离为0
|
368 |
+
distances[node[0]] = 0
|
369 |
+
# 如果节点是目的地
|
370 |
+
elif node[0] == destination:
|
371 |
+
# 记录距离为0
|
372 |
+
distances[node[0]] = 0
|
373 |
+
# 如果节点是递归节点
|
374 |
+
else:
|
375 |
+
# 遍历所有递归节点
|
376 |
+
for neighbor in node[1]:
|
377 |
+
# 如果递归节点与源节点之间的距离小于距离
|
378 |
+
if distances[neighbor[0]] > distances[node[0]]:
|
379 |
+
# 记录距离
|
380 |
+
distances[node[0]] = distances[neighbor[0]] + 1
|
381 |
+
```
|
382 |
+
|
383 |
+
- **prompt: 用python写一下不带注释的Dijkstra算法的代码**
|
384 |
+
|
385 |
+
```python
|
386 |
+
def dijkstra(graph, source, destination):
|
387 |
+
"""
|
388 |
+
Dijkstra算法
|
389 |
+
"""
|
390 |
+
# 初始化变量
|
391 |
+
dist = {}
|
392 |
+
visited = {}
|
393 |
+
shortest_path = {}
|
394 |
+
current_node = {}
|
395 |
+
|
396 |
+
# 遍历所有节点
|
397 |
+
for node in graph:
|
398 |
+
dist[node] = 0
|
399 |
+
visited[node] = False
|
400 |
+
current_node[node] = node
|
401 |
+
|
402 |
+
# 遍历所有节点
|
403 |
+
for node in graph:
|
404 |
+
if not visited[node] and node!= destination:
|
405 |
+
dist[node] = max(dist[node], dist[source] + 1)
|
406 |
+
visited[node] = True
|
407 |
+
shortest_path[node] = source
|
408 |
+
current_node[node] = node
|
409 |
+
|
410 |
+
# 返回最短路径
|
411 |
+
return shortest_path[destination]
|
412 |
+
```
|
413 |
+
|
414 |
+
**3、Prompt: 请用PyTorch编写一个带ReLU激活函数的全连接层**
|
415 |
+
|
416 |
+
```python
|
417 |
+
python
|
418 |
+
import torch
|
419 |
+
from torch.nn.modules.relu import ReLU
|
420 |
+
from torch.nn.init import initializer
|
421 |
+
|
422 |
+
class FullConnected(torch.nn.Module):
|
423 |
+
def __init__(self):
|
424 |
+
super(FullConnected, self).__init__()
|
425 |
+
self.relu = ReLU()
|
426 |
+
self.init = initializer.Normal(stddev=0.01)
|
427 |
+
|
428 |
+
def forward(self, x):
|
429 |
+
x = self.init(x)
|
430 |
+
x = self.relu(x)
|
431 |
+
return x
|
432 |
+
```
|
433 |
+
|
434 |
+
**4、Prompt: 请写出计算两个数的最大公约数(GCD)的Python代码**
|
435 |
+
|
436 |
+
```python
|
437 |
+
def gcd(a, b):
|
438 |
+
if a == 0:
|
439 |
+
return b
|
440 |
+
else:
|
441 |
+
return gcd(b, a % b)
|
442 |
+
```
|
443 |
+
</details>
|
444 |
+
|
445 |
+
|
446 |
+
</details>
|
447 |
+
|
448 |
+
## What we need?
|
449 |
+
|
450 |
+
- code:
|
451 |
+
|
452 |
+
- This code is developed based on alpaca-lora,https://github.com/tloen/alpaca-lora
|
453 |
+
- This is a relatively simple set of code, the basic idea is to use PEFT's lora interface + transformer's trainer + instruction data configuration
|
454 |
+
|
455 |
+
- data:
|
456 |
+
|
457 |
+
- We have utilized several current high quality open source data and are very grateful for their contributions. Many of these data use chatgpt's interface like alpaca to generate high quality INSTRUCTION data.
|
458 |
+
|
459 |
+
- [Belle](https://github.com/LianjiaTech/BELLE)
|
460 |
+
- [guanaco](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
461 |
+
|
462 |
+
- The data format is relatively simple, basically as follows, with simple examples such as: [`./sample/merge_sample.json`](https://github.com/Facico/Chinese-Vicuna/blob/master/sample/merge_sample.json)
|
463 |
+
|
464 |
+
- ```
|
465 |
+
{
|
466 |
+
'instruction':
|
467 |
+
'input':
|
468 |
+
'output'
|
469 |
+
}
|
470 |
+
```
|
471 |
+
|
472 |
+
- That is, an instruction, an input, and an output are required. since the data is processed by directly linking instruction and input, the data can actually require only instruction and output, as
|
473 |
+
|
474 |
+
```
|
475 |
+
{
|
476 |
+
'instruction': "用一句话描述地球为什么是独一无二的。\\n\n"
|
477 |
+
'input': ""
|
478 |
+
'output': "地球上有适宜生命存在的条件和多样化的生命形式。"
|
479 |
+
}
|
480 |
+
```
|
481 |
+
|
482 |
+
|
483 |
+
|
484 |
+
- The data we currently integrate is available for download on BaiduDownload or Google Drive or HuggingFace
|
485 |
+
|
486 |
+
- link: https://pan.baidu.com/s/1WSxuhSAotl14ifaAiz5eKw?pwd=b4kb password: b4kb
|
487 |
+
- link: https://drive.google.com/file/d/1tzXVhS74m-EtoFot7hEc005LDeZGPit_/view?usp=sharing
|
488 |
+
- link: https://huggingface.co/datasets/Chinese-Vicuna/guanaco_belle_merge_v1.0
|
489 |
+
|
490 |
+
- Large Language Model:
|
491 |
+
|
492 |
+
- LLAMA 7B(Of course, if you have a larger machine(such as 3090Ti) can be replaced with a 13B, LLAMA13B is numerically superior to 175B GPT3)
|
493 |
+
|
494 |
+
- LORA model:
|
495 |
+
|
496 |
+
- We provide some lora models trained on the above mixed data,
|
497 |
+
- You can also load our or other models from huggingface, load it by referring to [generate.py](https://github.com/Facico/Chinese-Vicuna/blob/master/generate.py)
|
498 |
+
- `Chinese-Vicuna/Chinese-Vicuna-lora-7b-belle-and-guanaco`
|
499 |
+
- `Chinese-Vicuna/Chinese-Vicuna-lora-13b-belle-and-guanaco`
|
500 |
+
- The model uses 8bit+lora+256 tokens
|
501 |
+
- For more LORA model, please see: https://huggingface.co/Chinese-Vicuna
|
502 |
+
|
503 |
+
- Device:
|
504 |
+
|
505 |
+
- Training: A 2080Ti is sufficient. Since the data length is within 256, it takes about 9G of video memory.
|
506 |
+
- 70w of data, 3 epochs, a 2080Ti about 200h
|
507 |
+
- 13B need about 18G(the cutoff_len can be set to 2048 in 3090Ti/4090Ti)
|
508 |
+
- Inference: A 2080Ti is all you need(7B), multiple GPU inference support 。
|
509 |
+
- CPU Inference is also support! please go to see [`tools`](https://github.com/Facico/Chinese-Vicuna/blob/master/tools)
|
510 |
+
|
511 |
+
## How to use
|
512 |
+
|
513 |
+
**Installation**
|
514 |
+
|
515 |
+
```
|
516 |
+
git clone https://github.com/Facico/Chinese-Vicuna
|
517 |
+
pip install -r requirements.txt
|
518 |
+
```
|
519 |
+
|
520 |
+
Local python environment is 3.8, torch is 1.13.1, CUDA is 12
|
521 |
+
|
522 |
+
NOTE: python3.11 has a known `torchrun` bug, details [here](https://github.com/facebookresearch/llama/issues/86)
|
523 |
+
|
524 |
+
|
525 |
+
### Newest Version=>4bit(qlora)/multi-gpu inference
|
526 |
+
```
|
527 |
+
pip install -r requirements_4bit.txt
|
528 |
+
```
|
529 |
+
This environment will encounter saving problems when training 8bit, which has not been solved yet(https://github.com/TimDettmers/bitsandbytes/issues/324)
|
530 |
+
|
531 |
+
|
532 |
+
**Multi-gpu Training**
|
533 |
+
#### for instruction tuning
|
534 |
+
**8bit**
|
535 |
+
|
536 |
+
```bash
|
537 |
+
bash scripts/finetune.sh
|
538 |
+
```
|
539 |
+
|
540 |
+
- The parameters to note here are as follows
|
541 |
+
- TOT_CUDA, fill in the GPU number to be used, such as `TOT_CUDA="0,1,2,3"`
|
542 |
+
- PORT, fill in the corresponding port
|
543 |
+
- DATA_PATH,fill in the corresponding data location in the format of json
|
544 |
+
- OUTPUT_PATH,fill in the relative path to save the model
|
545 |
+
- MODEL_PATH,path of LLM
|
546 |
+
- wandb: This is a training visualization tool that is not turned on by default in the script, and can be turned on by adding "--wandb" to the script
|
547 |
+
|
548 |
+
|
549 |
+
**4bit**
|
550 |
+
```bash
|
551 |
+
bash scripts/finetune_4bit.sh
|
552 |
+
```
|
553 |
+
|
554 |
+
#### for conversational instruction tuning
|
555 |
+
|
556 |
+
```bash
|
557 |
+
bash scripts/finetune_chat.sh
|
558 |
+
```
|
559 |
+
|
560 |
+
#### For the case where 8bit cannot be turned on / for commanded trimming of fp16
|
561 |
+
```bash
|
562 |
+
bash scripts/finetune_deepspeed.sh
|
563 |
+
```
|
564 |
+
|
565 |
+
- use_deepspeed:set to 1:use deepspeed. Otherwise use fp16
|
566 |
+
|
567 |
+
**Single-gpu Training**
|
568 |
+
|
569 |
+
```
|
570 |
+
CUDA_VISIBLE_DEVICES=0 python finetune.py --data_path merge.json --test_size 2000
|
571 |
+
```
|
572 |
+
|
573 |
+
- The test_size cannot be larger than the data size
|
574 |
+
|
575 |
+
**inference and use gradio to generate a web page**
|
576 |
+
|
577 |
+
```bash
|
578 |
+
bash scripts/generate.sh
|
579 |
+
```
|
580 |
+
|
581 |
+
- The parameters to note here are as follows
|
582 |
+
|
583 |
+
- BASE_MODEL,path of LLM
|
584 |
+
- LORA_PATH,The checkpoint folder of the lora model
|
585 |
+
- It should be noted here that the config loaded by the lora model must be "adapter_config.json" and the model name must be "adapter_model.bin", but it will be automatically saved as "pytorch_model.bin" during training. pytorch_model.bin" during training, while "adapter_config.json" and "adapter_model.bin" will be saved after all training is finished
|
586 |
+
- If you load the lora model in the training checkpoint, the code will automatically copy the local "config-sample/adapter_config.json" to the corresponding directory for you and rename the "pytorch_model.bin" to "adapter_model.bin". and rename "pytorch_model.bin" to "adapter_model.bin".
|
587 |
+
- It can also be any lora model on the huggingface corresponding to llama 7B, e.g.: `Facico/Chinese-Vicuna-lora-7b-3epoch-belle-and-guanaco`
|
588 |
+
- USE_LOCAL, which checks the local model configuration when set to 1
|
589 |
+
- When using, "max_tokens" is set according to your computer's video memory, and if the generated content generates a lot of duplicate information, you can turn up the "Repetition Penalty".
|
590 |
+
|
591 |
+
|
592 |
+
|
593 |
+
**Multi-round interaction**
|
594 |
+
|
595 |
+
We implemented our own chatbot with streaming output (typewriter-style) using `gradio`, supporting beam search, repetiion penalty settings, the ability to clear history, select different global instruction, etc.
|
596 |
+
|
597 |
+
```bash
|
598 |
+
bash scripts/chat_7B.sh
|
599 |
+
```
|
600 |
+
|
601 |
+
- A simple interactive interface constructed using gradio, which allows you to set the max_memory according to your machine (it will intercept the max_memory part later in the history conversation)
|
602 |
+
|
603 |
+
- The prompt used in this script is not quite the same as the one used in generate.sh. The prompt in this script is in the form of a dialogue, as follows
|
604 |
+
|
605 |
+
- ```
|
606 |
+
The following is a conversation between an AI assistant called Bot and a human user called User.
|
607 |
+
```
|
608 |
+
|
609 |
+
At the same time, for a better interactive experience,
|
610 |
+
|
611 |
+
## Checkpoint Retraining/Incremental Training
|
612 |
+
|
613 |
+
Considering the possibility that the program may be disconnected in the middle of the process, or the need to continue training on vertical domain data, we have provided corresponding interfaces.
|
614 |
+
|
615 |
+
The following are the default multi-GPU scripts. Please modify the single-GPU situation according to the above instruction(run directly in Python)
|
616 |
+
|
617 |
+
**Checkpoint Retraining**
|
618 |
+
|
619 |
+
```bash
|
620 |
+
bash scripts/finetune_continue.sh
|
621 |
+
```
|
622 |
+
|
623 |
+
- Set the `lora_checkpoint`
|
624 |
+
|
625 |
+
- If there are optimizer (optimizer.pt), lr policy (scheduler.pt), and other files in this directory, they will be automatically loaded and retrained from where they were broken
|
626 |
+
|
627 |
+
- If there are only LORA related models (adapter_model.bin) and configurations (adapter_config.json) in this directory, they will be loaded and trained from scratch
|
628 |
+
|
629 |
+
- `from_data_beginning`: The parameter indicates whether to start training from the beginning of the data when loading (default: starting training from the place where the data is disconnected)
|
630 |
+
|
631 |
+
**Incremental Training**
|
632 |
+
|
633 |
+
Of course, you can choose to continue training directly from a trained Lora model using the above script (without loading any optimizer parameters)
|
634 |
+
|
635 |
+
You can also continue training from our optimizer parameters
|
636 |
+
|
637 |
+
```
|
638 |
+
finetune_others_continue.sh
|
639 |
+
```
|
640 |
+
|
641 |
+
- `from_data_beginning`: This will default to training from the beginning of the data
|
642 |
+
|
643 |
+
The logic of this script is mainly to keep the learning rate consistent. If your `max_steps` is smaller than ours, keep `max_steps `consistent with our `max_steps` during training, which is equivalent to putting your data directly behind our disconnected data; if your data set larger than us and will remain directly unchanged.
|
644 |
+
|
645 |
+
|
646 |
+
|
647 |
+
We currently directly provide checkpoints after 1 epoch and 2 epoch training
|
648 |
+
|
649 |
+
- 1epoch: https://github.com/Facico/Chinese-Vicuna/tree/master/lora-Vicuna/checkpoint-5800
|
650 |
+
- 2epoch: https://github.com/Facico/Chinese-Vicuna/tree/master/lora-Vicuna/checkpoint-11600
|
651 |
+
- If you use our checkpoint, your program will also continue from the corresponding step
|
652 |
+
|
653 |
+
### Specific cases
|
654 |
+
|
655 |
+
- Continue-finetune on the vertical corpus of medicalQA , see here [Chinese-Vicuna-medical](https://github.com/Facico/Chinese-Vicuna/blob/master/docs/performance-medical.md)
|
656 |
+
## **inference on CPU with pure C++**
|
657 |
+
|
658 |
+
Details in `tools` [readme](https://github.com/Facico/Chinese-Vicuna/blob/master/tools/readme.md)
|
659 |
+
|
660 |
+
## **More Tools**
|
661 |
+
|
662 |
+
We also offer:
|
663 |
+
- ways for faster weight download ( 8MB/s ) : [link](https://github.com/Facico/Chinese-Vicuna/blob/master/tools/download_llama.sh)
|
664 |
+
- convert tools between the original facebook checkpoint (`consolidated.xx.pth`) and huggingface format (`pytorch_model-000xx-of-000xx.bin`): [link](https://github.com/Facico/Chinese-Vicuna/blob/master/tools/convert_llama.py)
|
665 |
+
- a quantitative approach that requires less than 4G graphics memory for inference: [link](https://github.com/Facico/Chinese-Vicuna/blob/master/tools/llama_quant.py)
|
666 |
+
|
667 |
+
For more details, see [tool readme](https://github.com/Facico/Chinese-Vicuna/tree/master/tools)
|
668 |
+
|
669 |
+
|
670 |
+
# Todo
|
671 |
+
|
672 |
+
- [x] belle+guanaco(1.5 epoch, 8000 step)
|
673 |
+
- [x] belle+guanaco(100%)
|
674 |
+
- [x] Add more chitchat-like conversational corpus to enhance free conversation
|
675 |
+
- [x] Add colab training + lora loading interface
|
676 |
+
- [x] Add the interaction capabilities and typewrite-style output(beam search+streaming output)
|
677 |
+
- [x] Add llama c++ inference
|
678 |
+
- [x] Add gptq quantification tools
|
679 |
+
- [x] Add incremental training
|
680 |
+
- [x] train on multi-turn instruction dataset
|
681 |
+
- [x] train more epoch on cleaned instruct-chat combination data
|
682 |
+
- [x] train on domain-specific datas (medical, legal)
|
683 |
+
- [ ] add langchain
|
684 |
+
|
685 |
+
# Star History
|
686 |
+
[](https://star-history.com/#Facico/Chinese-Vicuna&Date)
|
687 |
+
|
688 |
+
# Citation
|
689 |
+
|
690 |
+
If you find this project useful in your research, please consider citing:
|
691 |
+
|
692 |
+
```
|
693 |
+
@inproceedings{leng2023chinese-vicuna,
|
694 |
+
title={Chinese-Vicuna: A Chinese Instruction-following LLaMA-based Model},
|
695 |
+
author={Chenghao Fan, Zhenyi Lu and Jie Tian},
|
696 |
+
url={https://github.com/Facico/Chinese-Vicuna},
|
697 |
+
year={2023}
|
698 |
+
}
|
699 |
+
```
|
chat.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
|
4 |
+
import transformers
|
5 |
+
import json
|
6 |
+
import gradio as gr
|
7 |
+
import argparse
|
8 |
+
import warnings
|
9 |
+
import os
|
10 |
+
from datetime import datetime
|
11 |
+
from utils import StreamPeftGenerationMixin,StreamLlamaForCausalLM, printf
|
12 |
+
import utils
|
13 |
+
import copy
|
14 |
+
assert (
|
15 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
16 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
17 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
|
18 |
+
import prompt
|
19 |
+
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
|
22 |
+
parser.add_argument("--lora_path", type=str, default='')
|
23 |
+
parser.add_argument("--use_typewriter", type=int, default=1)
|
24 |
+
parser.add_argument("--prompt_type", type=str, default='chat')
|
25 |
+
parser.add_argument("--share_link", type=int, default=0)
|
26 |
+
parser.add_argument("--show_beam", type=int, default=0)
|
27 |
+
parser.add_argument("--int8", type=int, default=1)
|
28 |
+
args = parser.parse_args()
|
29 |
+
args.fix_token = True
|
30 |
+
printf('>>> args:', args)
|
31 |
+
tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
|
32 |
+
|
33 |
+
LOAD_8BIT = args.int8
|
34 |
+
BASE_MODEL = args.model_path
|
35 |
+
LORA_WEIGHTS = args.lora_path
|
36 |
+
|
37 |
+
# fix the path for local checkpoint
|
38 |
+
lora_bin_path = os.path.join(args.lora_path, "adapter_model.bin")
|
39 |
+
if args.lora_path != '' and os.path.exists(args.lora_path):
|
40 |
+
if not os.path.exists(lora_bin_path):
|
41 |
+
pytorch_bin_path = os.path.join(args.lora_path, "pytorch_model.bin")
|
42 |
+
printf('>>> load lora from', pytorch_bin_path)
|
43 |
+
if os.path.exists(pytorch_bin_path):
|
44 |
+
os.rename(pytorch_bin_path, lora_bin_path)
|
45 |
+
warnings.warn(
|
46 |
+
"The file name of the lora checkpoint'pytorch_model.bin' is replaced with 'adapter_model.bin'"
|
47 |
+
)
|
48 |
+
else:
|
49 |
+
assert ('Checkpoint is not Found!')
|
50 |
+
else:
|
51 |
+
printf('>>> load lora from', lora_bin_path)
|
52 |
+
else:
|
53 |
+
printf('>>> load lora from huggingface url', args.lora_path)
|
54 |
+
|
55 |
+
if torch.cuda.is_available():
|
56 |
+
device = "cuda"
|
57 |
+
else:
|
58 |
+
device = "cpu"
|
59 |
+
|
60 |
+
try:
|
61 |
+
if torch.backends.mps.is_available():
|
62 |
+
device = "mps"
|
63 |
+
except:
|
64 |
+
pass
|
65 |
+
|
66 |
+
if device == "cuda":
|
67 |
+
print(f'>>> load raw models from {BASE_MODEL}')
|
68 |
+
if args.lora_path == "":
|
69 |
+
model = StreamLlamaForCausalLM.from_pretrained(
|
70 |
+
BASE_MODEL,
|
71 |
+
load_in_8bit=LOAD_8BIT,
|
72 |
+
torch_dtype=torch.float16,
|
73 |
+
device_map={"": 0},
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
print(f'>>> load lora models from {LORA_WEIGHTS}')
|
77 |
+
model = LlamaForCausalLM.from_pretrained(
|
78 |
+
BASE_MODEL,
|
79 |
+
load_in_8bit=LOAD_8BIT,
|
80 |
+
torch_dtype=torch.float16,
|
81 |
+
device_map={"": 0},
|
82 |
+
)
|
83 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
84 |
+
model, LORA_WEIGHTS, torch_dtype=torch.float16, load_in_8bit=LOAD_8BIT, device_map={"": 0}
|
85 |
+
)
|
86 |
+
elif device == "mps":
|
87 |
+
model = LlamaForCausalLM.from_pretrained(
|
88 |
+
BASE_MODEL,
|
89 |
+
device_map={"": device},
|
90 |
+
torch_dtype=torch.float16,
|
91 |
+
)
|
92 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
93 |
+
model,
|
94 |
+
LORA_WEIGHTS,
|
95 |
+
device_map={"": device},
|
96 |
+
torch_dtype=torch.float16,
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
model = LlamaForCausalLM.from_pretrained(
|
100 |
+
BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
|
101 |
+
)
|
102 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
103 |
+
model,
|
104 |
+
LORA_WEIGHTS,
|
105 |
+
device_map={"": device},
|
106 |
+
)
|
107 |
+
# fix tokenizer bug
|
108 |
+
if args.fix_token and tokenizer.eos_token_id != 2:
|
109 |
+
warnings.warn(
|
110 |
+
"The tokenizer eos token may be wrong. please check you llama-checkpoint"
|
111 |
+
)
|
112 |
+
model.config.bos_token_id = tokenizer.bos_token_id = 1
|
113 |
+
model.config.eos_token_id = tokenizer.eos_token_id = 2
|
114 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0 # same as unk token id
|
115 |
+
if not LOAD_8BIT:
|
116 |
+
model.half() # seems to fix bugs for some users.
|
117 |
+
|
118 |
+
model.eval()
|
119 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
120 |
+
model = torch.compile(model)
|
121 |
+
|
122 |
+
def save(
|
123 |
+
inputs,
|
124 |
+
history,
|
125 |
+
temperature=0.1,
|
126 |
+
top_p=0.75,
|
127 |
+
top_k=40,
|
128 |
+
num_beams=4,
|
129 |
+
max_new_tokens=128,
|
130 |
+
min_new_tokens=1,
|
131 |
+
repetition_penalty=2.0,
|
132 |
+
max_memory=1024,
|
133 |
+
do_sample=False,
|
134 |
+
prompt_type='0',
|
135 |
+
**kwargs,
|
136 |
+
):
|
137 |
+
history = [] if history is None else history
|
138 |
+
data_point = {}
|
139 |
+
if prompt_type == 'instruct':
|
140 |
+
PROMPT = prompt.instruct_prompt(tokenizer,max_memory)
|
141 |
+
elif prompt_type == 'chat':
|
142 |
+
PROMPT = prompt.chat_prompt(tokenizer,max_memory)
|
143 |
+
else:
|
144 |
+
raise Exception('not support')
|
145 |
+
data_point['history'] = history
|
146 |
+
# 实际上是每一步都可以不一样,这里只保存最后一步
|
147 |
+
data_point['generation_parameter'] = {
|
148 |
+
"temperature":temperature,
|
149 |
+
"top_p":top_p,
|
150 |
+
"top_k":top_k,
|
151 |
+
"num_beams":num_beams,
|
152 |
+
"bos_token_id":tokenizer.bos_token_id,
|
153 |
+
"eos_token_id":tokenizer.eos_token_id,
|
154 |
+
"pad_token_id":tokenizer.pad_token_id,
|
155 |
+
"max_new_tokens":max_new_tokens,
|
156 |
+
"min_new_tokens":min_new_tokens,
|
157 |
+
"do_sample":do_sample,
|
158 |
+
"repetition_penalty":repetition_penalty,
|
159 |
+
"max_memory":max_memory,
|
160 |
+
}
|
161 |
+
data_point['info'] = args.__dict__
|
162 |
+
print(data_point)
|
163 |
+
if args.int8:
|
164 |
+
file_name = f"{args.lora_path}/{args.prompt_type.replace(' ','_')}_int8.jsonl"
|
165 |
+
else:
|
166 |
+
file_name = f"{args.lora_path}/{args.prompt_type.replace(' ','_')}_fp16.jsonl"
|
167 |
+
utils.to_jsonl([data_point], file_name)
|
168 |
+
|
169 |
+
def evaluate(
|
170 |
+
inputs,
|
171 |
+
history,
|
172 |
+
temperature=0.1,
|
173 |
+
top_p=0.75,
|
174 |
+
top_k=40,
|
175 |
+
num_beams=4,
|
176 |
+
max_new_tokens=128,
|
177 |
+
min_new_tokens=1,
|
178 |
+
repetition_penalty=2.0,
|
179 |
+
max_memory=1024,
|
180 |
+
do_sample=False,
|
181 |
+
prompt_type='0',
|
182 |
+
**kwargs,
|
183 |
+
):
|
184 |
+
history = [] if history is None else history
|
185 |
+
data_point = {}
|
186 |
+
if prompt_type == 'instruct':
|
187 |
+
PROMPT = prompt.instruct_prompt(tokenizer,max_memory)
|
188 |
+
elif prompt_type == 'chat':
|
189 |
+
PROMPT = prompt.chat_prompt(tokenizer,max_memory)
|
190 |
+
else:
|
191 |
+
raise Exception('not support')
|
192 |
+
|
193 |
+
data_point['history'] = copy.deepcopy(history)
|
194 |
+
data_point['input'] = inputs
|
195 |
+
|
196 |
+
input_ids = PROMPT.preprocess_gen(data_point)
|
197 |
+
|
198 |
+
printf('------------------------------')
|
199 |
+
printf(tokenizer.decode(input_ids))
|
200 |
+
input_ids = torch.tensor([input_ids]).to(device) # batch=1
|
201 |
+
|
202 |
+
printf('------------------------------')
|
203 |
+
printf('shape',input_ids.size())
|
204 |
+
printf('------------------------------')
|
205 |
+
generation_config = GenerationConfig(
|
206 |
+
temperature=temperature,
|
207 |
+
top_p=top_p,
|
208 |
+
top_k=top_k,
|
209 |
+
num_beams=num_beams,
|
210 |
+
bos_token_id=tokenizer.bos_token_id,
|
211 |
+
eos_token_id=tokenizer.eos_token_id,
|
212 |
+
pad_token_id=tokenizer.pad_token_id,
|
213 |
+
max_new_tokens=max_new_tokens, # max_length=max_new_tokens+input_sequence
|
214 |
+
min_new_tokens=min_new_tokens, # min_length=min_new_tokens+input_sequence
|
215 |
+
do_sample=do_sample,
|
216 |
+
bad_words_ids=tokenizer(['\n\nUser:','\n\nAssistant:'], add_special_tokens=False).input_ids,
|
217 |
+
|
218 |
+
**kwargs,
|
219 |
+
)
|
220 |
+
|
221 |
+
return_text = [(item['input'], item['output']) for item in history]
|
222 |
+
out_memory =False
|
223 |
+
outputs = None
|
224 |
+
with torch.no_grad():
|
225 |
+
# 流式输出 / 打字机效果
|
226 |
+
# streamly output / typewriter style
|
227 |
+
if args.use_typewriter:
|
228 |
+
try:
|
229 |
+
for generation_output in model.stream_generate(
|
230 |
+
input_ids=input_ids,
|
231 |
+
generation_config=generation_config,
|
232 |
+
return_dict_in_generate=True,
|
233 |
+
output_scores=False,
|
234 |
+
repetition_penalty=float(repetition_penalty),
|
235 |
+
):
|
236 |
+
gen_token = generation_output[0][-1].item()
|
237 |
+
printf(gen_token, end='(')
|
238 |
+
printf(tokenizer.decode(gen_token), end=') ')
|
239 |
+
|
240 |
+
outputs = tokenizer.batch_decode(generation_output)
|
241 |
+
if args.show_beam:
|
242 |
+
show_text = "\n--------------------------------------------\n".join(
|
243 |
+
[ PROMPT.postprocess(output)+" ▌" for output in outputs]
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
show_text = PROMPT.postprocess(outputs[0])+" ▌"
|
247 |
+
yield return_text +[(inputs, show_text)], history
|
248 |
+
except torch.cuda.OutOfMemoryError:
|
249 |
+
print('CUDA out of memory')
|
250 |
+
import gc
|
251 |
+
gc.collect()
|
252 |
+
torch.cuda.empty_cache()
|
253 |
+
out_memory=True
|
254 |
+
# finally only one
|
255 |
+
printf('[EOS]', end='\n')
|
256 |
+
show_text = PROMPT.postprocess(outputs[0] if outputs is not None else '### Response:')
|
257 |
+
return_len = len(show_text)
|
258 |
+
if out_memory==True:
|
259 |
+
out_memory=False
|
260 |
+
show_text+= '<p style="color:#FF0000"> [GPU Out Of Memory] </p> '
|
261 |
+
if return_len > 0:
|
262 |
+
output = PROMPT.postprocess(outputs[0], render=False)
|
263 |
+
history.append({
|
264 |
+
'input': inputs,
|
265 |
+
'output': output,
|
266 |
+
})
|
267 |
+
|
268 |
+
return_text += [(inputs, show_text)]
|
269 |
+
yield return_text, history
|
270 |
+
# common
|
271 |
+
else:
|
272 |
+
try:
|
273 |
+
generation_output = model.generate(
|
274 |
+
input_ids=input_ids,
|
275 |
+
generation_config=generation_config,
|
276 |
+
return_dict_in_generate=True,
|
277 |
+
output_scores=True,
|
278 |
+
max_new_tokens=max_new_tokens,
|
279 |
+
repetition_penalty=float(repetition_penalty),
|
280 |
+
)
|
281 |
+
s = generation_output.sequences[0]
|
282 |
+
output = tokenizer.decode(s)
|
283 |
+
output = PROMPT.postprocess(output)
|
284 |
+
history.append({
|
285 |
+
'input': inputs,
|
286 |
+
'output': output,
|
287 |
+
})
|
288 |
+
return_text += [(inputs, output)]
|
289 |
+
yield return_text, history
|
290 |
+
except torch.cuda.OutOfMemoryError:
|
291 |
+
import gc
|
292 |
+
gc.collect()
|
293 |
+
torch.cuda.empty_cache()
|
294 |
+
show_text = '<p style="color:#FF0000"> [GPU Out Of Memory] </p> '
|
295 |
+
printf(show_text)
|
296 |
+
return_text += [(inputs, show_text)]
|
297 |
+
yield return_text, history
|
298 |
+
|
299 |
+
def clear():
|
300 |
+
import gc
|
301 |
+
gc.collect()
|
302 |
+
torch.cuda.empty_cache()
|
303 |
+
return None, None
|
304 |
+
|
305 |
+
|
306 |
+
# gr.Interface对chatbot的clear有bug,因此我们重新实现了一个基于gr.block的UI逻辑
|
307 |
+
# gr.Interface has bugs to clear chatbot's history,so we customly implement it based on gr.block
|
308 |
+
with gr.Blocks() as demo:
|
309 |
+
fn = evaluate
|
310 |
+
title = gr.Markdown(
|
311 |
+
"<h1 style='text-align: center; margin-bottom: 1rem'>"
|
312 |
+
+ "Chinese-Vicuna 中文小羊驼"
|
313 |
+
+ "</h1>"
|
314 |
+
)
|
315 |
+
description = gr.Markdown(
|
316 |
+
"中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来,模型基于开源的llama7B,主要贡献是对应的lora模型。由于代码训练资源要求较小,希望为llama中文lora社区做一份贡献。"
|
317 |
+
)
|
318 |
+
history = gr.components.State()
|
319 |
+
with gr.Row().style(equal_height=False):
|
320 |
+
with gr.Column(variant="panel"):
|
321 |
+
input_component_column = gr.Column()
|
322 |
+
with input_component_column:
|
323 |
+
input = gr.components.Textbox(
|
324 |
+
lines=2, label="Input", placeholder="请输入问题."
|
325 |
+
)
|
326 |
+
temperature = gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Temperature")
|
327 |
+
topp = gr.components.Slider(minimum=0, maximum=1, value=0.9, label="Top p")
|
328 |
+
topk = gr.components.Slider(minimum=0, maximum=100, step=1, value=60, label="Top k")
|
329 |
+
beam_number = gr.components.Slider(minimum=1, maximum=10, step=1, value=4, label="Beams Number")
|
330 |
+
max_new_token = gr.components.Slider(
|
331 |
+
minimum=1, maximum=2048, step=1, value=256, label="Max New Tokens"
|
332 |
+
)
|
333 |
+
min_new_token = gr.components.Slider(
|
334 |
+
minimum=1, maximum=1024, step=1, value=5, label="Min New Tokens"
|
335 |
+
)
|
336 |
+
repeat_penal = gr.components.Slider(
|
337 |
+
minimum=0.1, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"
|
338 |
+
)
|
339 |
+
max_memory = gr.components.Slider(
|
340 |
+
minimum=0, maximum=2048, step=1, value=2048, label="Max Memory"
|
341 |
+
)
|
342 |
+
do_sample = gr.components.Checkbox(label="Use sample")
|
343 |
+
# must be str, not number !
|
344 |
+
type_of_prompt = gr.components.Dropdown(
|
345 |
+
['instruct', 'chat'], value=args.prompt_type, label="Prompt Type", info="select the specific prompt; use after clear history"
|
346 |
+
)
|
347 |
+
input_components = [
|
348 |
+
input, history, temperature, topp, topk, beam_number, max_new_token, min_new_token, repeat_penal, max_memory, do_sample, type_of_prompt
|
349 |
+
]
|
350 |
+
input_components_except_states = [input, temperature, topp, topk, beam_number, max_new_token, min_new_token, repeat_penal, max_memory, do_sample, type_of_prompt]
|
351 |
+
with gr.Row():
|
352 |
+
cancel_btn = gr.Button('Cancel')
|
353 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
354 |
+
stop_btn = gr.Button("Stop", variant="stop", visible=False)
|
355 |
+
with gr.Row():
|
356 |
+
reset_btn = gr.Button("Reset Parameter")
|
357 |
+
clear_history = gr.Button("Clear History")
|
358 |
+
|
359 |
+
|
360 |
+
with gr.Column(variant="panel"):
|
361 |
+
chatbot = gr.Chatbot().style(height=1024)
|
362 |
+
output_components = [ chatbot, history ]
|
363 |
+
with gr.Row():
|
364 |
+
save_btn = gr.Button("Save Chat")
|
365 |
+
def wrapper(*args):
|
366 |
+
# here to support the change between the stop and submit button
|
367 |
+
try:
|
368 |
+
for output in fn(*args):
|
369 |
+
output = [o for o in output]
|
370 |
+
# output for output_components, the rest for [button, button]
|
371 |
+
yield output + [
|
372 |
+
gr.Button.update(visible=False),
|
373 |
+
gr.Button.update(visible=True),
|
374 |
+
]
|
375 |
+
finally:
|
376 |
+
yield [{'__type__': 'generic_update'}, {'__type__': 'generic_update'}] + [ gr.Button.update(visible=True), gr.Button.update(visible=False)]
|
377 |
+
|
378 |
+
def cancel(history, chatbot):
|
379 |
+
if history == []:
|
380 |
+
return (None, None)
|
381 |
+
return history[:-1], chatbot[:-1]
|
382 |
+
|
383 |
+
extra_output = [submit_btn, stop_btn]
|
384 |
+
save_btn.click(
|
385 |
+
save,
|
386 |
+
input_components,
|
387 |
+
None,
|
388 |
+
)
|
389 |
+
pred = submit_btn.click(
|
390 |
+
wrapper,
|
391 |
+
input_components,
|
392 |
+
output_components + extra_output,
|
393 |
+
api_name="predict",
|
394 |
+
scroll_to_output=True,
|
395 |
+
preprocess=True,
|
396 |
+
postprocess=True,
|
397 |
+
batch=False,
|
398 |
+
max_batch_size=4,
|
399 |
+
)
|
400 |
+
submit_btn.click(
|
401 |
+
lambda: (
|
402 |
+
submit_btn.update(visible=False),
|
403 |
+
stop_btn.update(visible=True),
|
404 |
+
),
|
405 |
+
inputs=None,
|
406 |
+
outputs=[submit_btn, stop_btn],
|
407 |
+
queue=False,
|
408 |
+
)
|
409 |
+
stop_btn.click(
|
410 |
+
lambda: (
|
411 |
+
submit_btn.update(visible=True),
|
412 |
+
stop_btn.update(visible=False),
|
413 |
+
),
|
414 |
+
inputs=None,
|
415 |
+
outputs=[submit_btn, stop_btn],
|
416 |
+
cancels=[pred],
|
417 |
+
queue=False,
|
418 |
+
)
|
419 |
+
cancel_btn.click(
|
420 |
+
cancel,
|
421 |
+
inputs=[history, chatbot],
|
422 |
+
outputs=[history, chatbot]
|
423 |
+
)
|
424 |
+
reset_btn.click(
|
425 |
+
None,
|
426 |
+
[],
|
427 |
+
(
|
428 |
+
# input_components ; don't work for history...
|
429 |
+
input_components_except_states
|
430 |
+
+ [input_component_column]
|
431 |
+
), # type: ignore
|
432 |
+
_js=f"""() => {json.dumps([
|
433 |
+
getattr(component, "cleared_value", None) for component in input_components_except_states ]
|
434 |
+
+ ([gr.Column.update(visible=True)])
|
435 |
+
+ ([])
|
436 |
+
)}
|
437 |
+
""",
|
438 |
+
)
|
439 |
+
clear_history.click(clear, None, [history, chatbot], queue=False)
|
440 |
+
|
441 |
+
demo.queue().launch(share=args.share_link)
|
finetune.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import bitsandbytes as bnb
|
7 |
+
from datasets import load_dataset
|
8 |
+
import transformers
|
9 |
+
import argparse
|
10 |
+
import warnings
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
|
13 |
+
assert (
|
14 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
15 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
16 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
17 |
+
from peft import (
|
18 |
+
prepare_model_for_int8_training,
|
19 |
+
LoraConfig,
|
20 |
+
get_peft_model,
|
21 |
+
get_peft_model_state_dict,
|
22 |
+
set_peft_model_state_dict,
|
23 |
+
)
|
24 |
+
|
25 |
+
parser = argparse.ArgumentParser()
|
26 |
+
parser.add_argument("--wandb", action="store_true", default=False)
|
27 |
+
parser.add_argument("--data_path", type=str, default="merge.json")
|
28 |
+
parser.add_argument("--output_path", type=str, default="lora-Vicuna")
|
29 |
+
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
|
30 |
+
parser.add_argument("--eval_steps", type=int, default=200)
|
31 |
+
parser.add_argument("--save_steps", type=int, default=200)
|
32 |
+
parser.add_argument("--test_size", type=int, default=200)
|
33 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
34 |
+
parser.add_argument("--lora_remote_checkpoint", type=str, default=None)
|
35 |
+
parser.add_argument("--ignore_data_skip", type=str, default="False")
|
36 |
+
args = parser.parse_args()
|
37 |
+
|
38 |
+
if not args.wandb:
|
39 |
+
os.environ["WANDB_MODE"] = "disable"
|
40 |
+
# optimized for RTX 4090. for larger GPUs, increase some of these?
|
41 |
+
MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
|
42 |
+
BATCH_SIZE = 128
|
43 |
+
MAX_STEPS = None
|
44 |
+
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
45 |
+
EPOCHS = 3 # we don't always need 3 tbh
|
46 |
+
LEARNING_RATE = 3e-4 # the Karpathy constant
|
47 |
+
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
|
48 |
+
LORA_R = 8
|
49 |
+
LORA_ALPHA = 16
|
50 |
+
LORA_DROPOUT = 0.05
|
51 |
+
VAL_SET_SIZE = args.test_size #2000
|
52 |
+
USE_8bit = True
|
53 |
+
|
54 |
+
if USE_8bit is True:
|
55 |
+
warnings.warn("If your version of bitsandbytes>0.37.2, Please downgrade bitsandbytes's version, for example: pip install bitsandbytes==0.37.2")
|
56 |
+
|
57 |
+
TARGET_MODULES = [
|
58 |
+
"q_proj",
|
59 |
+
"v_proj",
|
60 |
+
]
|
61 |
+
DATA_PATH = args.data_path
|
62 |
+
OUTPUT_DIR = args.output_path #"lora-Vicuna"
|
63 |
+
|
64 |
+
device_map = "auto"
|
65 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
66 |
+
ddp = world_size != 1
|
67 |
+
if ddp:
|
68 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
69 |
+
GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
|
70 |
+
print(args.model_path)
|
71 |
+
model = LlamaForCausalLM.from_pretrained(
|
72 |
+
args.model_path,
|
73 |
+
load_in_8bit=USE_8bit,
|
74 |
+
device_map=device_map,
|
75 |
+
)
|
76 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
77 |
+
args.model_path, add_eos_token=True
|
78 |
+
)
|
79 |
+
|
80 |
+
if USE_8bit is True:
|
81 |
+
model = prepare_model_for_int8_training(model)
|
82 |
+
|
83 |
+
config = LoraConfig(
|
84 |
+
r=LORA_R,
|
85 |
+
lora_alpha=LORA_ALPHA,
|
86 |
+
target_modules=TARGET_MODULES,
|
87 |
+
lora_dropout=LORA_DROPOUT,
|
88 |
+
bias="none",
|
89 |
+
task_type="CAUSAL_LM",
|
90 |
+
)
|
91 |
+
model = get_peft_model(model, config)
|
92 |
+
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
|
93 |
+
#tokenizer.padding_side = "left" # Allow batched inference
|
94 |
+
|
95 |
+
data = load_dataset("json", data_files=DATA_PATH)
|
96 |
+
|
97 |
+
now_max_steps = max((len(data["train"]) - VAL_SET_SIZE) // BATCH_SIZE * EPOCHS, EPOCHS)
|
98 |
+
if args.resume_from_checkpoint:
|
99 |
+
if args.lora_remote_checkpoint is not None:
|
100 |
+
snapshot_download(repo_id=args.lora_remote_checkpoint, allow_patterns=["*.pt", "*.bin", "*.json"], local_dir=args.resume_from_checkpoint)
|
101 |
+
# Check the available weights and load them
|
102 |
+
checkpoint_name = os.path.join(
|
103 |
+
args.resume_from_checkpoint, "pytorch_model.bin"
|
104 |
+
) # Full checkpoint
|
105 |
+
if not os.path.exists(checkpoint_name):
|
106 |
+
pytorch_bin_path = checkpoint_name
|
107 |
+
checkpoint_name = os.path.join(
|
108 |
+
args.resume_from_checkpoint, "adapter_model.bin"
|
109 |
+
) # only LoRA model - LoRA config above has to fit
|
110 |
+
if os.path.exists(checkpoint_name):
|
111 |
+
os.rename(checkpoint_name, pytorch_bin_path)
|
112 |
+
warnings.warn("The file name of the lora checkpoint'adapter_model.bin' is replaced with 'pytorch_model.bin'")
|
113 |
+
else:
|
114 |
+
args.resume_from_checkpoint = (
|
115 |
+
None # So the trainer won't try loading its state
|
116 |
+
)
|
117 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
118 |
+
if os.path.exists(checkpoint_name):
|
119 |
+
print(f"Restarting from {checkpoint_name}")
|
120 |
+
adapters_weights = torch.load(checkpoint_name)
|
121 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
122 |
+
else:
|
123 |
+
print(f"Checkpoint {checkpoint_name} not found")
|
124 |
+
|
125 |
+
train_args_path = os.path.join(args.resume_from_checkpoint, "trainer_state.json")
|
126 |
+
|
127 |
+
if os.path.exists(train_args_path):
|
128 |
+
import json
|
129 |
+
base_train_args = json.load(open(train_args_path, 'r'))
|
130 |
+
base_max_steps = base_train_args["max_steps"]
|
131 |
+
resume_scale = base_max_steps / now_max_steps
|
132 |
+
if base_max_steps > now_max_steps:
|
133 |
+
warnings.warn("epoch {} replace to the base_max_steps {}".format(EPOCHS, base_max_steps))
|
134 |
+
EPOCHS = None
|
135 |
+
MAX_STEPS = base_max_steps
|
136 |
+
else:
|
137 |
+
MAX_STEPS = now_max_steps
|
138 |
+
else:
|
139 |
+
MAX_STEPS = now_max_steps
|
140 |
+
|
141 |
+
|
142 |
+
model.print_trainable_parameters()
|
143 |
+
|
144 |
+
def generate_prompt(data_point):
|
145 |
+
# sorry about the formatting disaster gotta move fast
|
146 |
+
if data_point["input"]:
|
147 |
+
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
148 |
+
|
149 |
+
### Instruction:
|
150 |
+
{data_point["instruction"]}
|
151 |
+
|
152 |
+
### Input:
|
153 |
+
{data_point["input"]}
|
154 |
+
|
155 |
+
### Response:
|
156 |
+
{data_point["output"]}"""
|
157 |
+
else:
|
158 |
+
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
159 |
+
|
160 |
+
### Instruction:
|
161 |
+
{data_point["instruction"]}
|
162 |
+
|
163 |
+
### Response:
|
164 |
+
{data_point["output"]}"""
|
165 |
+
|
166 |
+
|
167 |
+
def tokenize(prompt):
|
168 |
+
# there's probably a way to do this with the tokenizer settings
|
169 |
+
# but again, gotta move fast
|
170 |
+
result = tokenizer(
|
171 |
+
prompt,
|
172 |
+
truncation=True,
|
173 |
+
max_length=CUTOFF_LEN + 1,
|
174 |
+
padding="max_length",
|
175 |
+
)
|
176 |
+
return {
|
177 |
+
"input_ids": result["input_ids"][:-1],
|
178 |
+
"attention_mask": result["attention_mask"][:-1],
|
179 |
+
}
|
180 |
+
|
181 |
+
|
182 |
+
def generate_and_tokenize_prompt(data_point):
|
183 |
+
# This function masks out the labels for the input,
|
184 |
+
# so that our loss is computed only on the response.
|
185 |
+
user_prompt = (
|
186 |
+
(
|
187 |
+
f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
188 |
+
|
189 |
+
### Instruction:
|
190 |
+
{data_point["instruction"]}
|
191 |
+
|
192 |
+
### Input:
|
193 |
+
{data_point["input"]}
|
194 |
+
|
195 |
+
### Response:
|
196 |
+
"""
|
197 |
+
)
|
198 |
+
if data_point["input"]
|
199 |
+
else (
|
200 |
+
f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
201 |
+
|
202 |
+
### Instruction:
|
203 |
+
{data_point["instruction"]}
|
204 |
+
|
205 |
+
### Response:
|
206 |
+
"""
|
207 |
+
)
|
208 |
+
)
|
209 |
+
len_user_prompt_tokens = (
|
210 |
+
len(
|
211 |
+
tokenizer(
|
212 |
+
user_prompt,
|
213 |
+
truncation=True,
|
214 |
+
max_length=CUTOFF_LEN + 1,
|
215 |
+
)["input_ids"]
|
216 |
+
)
|
217 |
+
- 1
|
218 |
+
) # no eos token
|
219 |
+
full_tokens = tokenizer(
|
220 |
+
user_prompt + data_point["output"],
|
221 |
+
truncation=True,
|
222 |
+
max_length=CUTOFF_LEN + 1,
|
223 |
+
padding="max_length",
|
224 |
+
)["input_ids"][:-1]
|
225 |
+
return {
|
226 |
+
"input_ids": full_tokens,
|
227 |
+
"labels": [-100] * len_user_prompt_tokens
|
228 |
+
+ full_tokens[len_user_prompt_tokens:],
|
229 |
+
"attention_mask": [1] * (len(full_tokens)),
|
230 |
+
}
|
231 |
+
|
232 |
+
|
233 |
+
if VAL_SET_SIZE > 0:
|
234 |
+
train_val = data["train"].train_test_split(
|
235 |
+
test_size=VAL_SET_SIZE, shuffle=True, seed=42
|
236 |
+
)
|
237 |
+
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
238 |
+
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
239 |
+
else:
|
240 |
+
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
|
241 |
+
val_data = None
|
242 |
+
|
243 |
+
trainer = transformers.Trainer(
|
244 |
+
model=model,
|
245 |
+
train_dataset=train_data,
|
246 |
+
eval_dataset=val_data,
|
247 |
+
args=transformers.TrainingArguments(
|
248 |
+
per_device_train_batch_size=MICRO_BATCH_SIZE,
|
249 |
+
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
|
250 |
+
warmup_steps=100,
|
251 |
+
num_train_epochs=EPOCHS,
|
252 |
+
max_steps=MAX_STEPS,
|
253 |
+
learning_rate=LEARNING_RATE,
|
254 |
+
fp16=True,
|
255 |
+
logging_steps=20,
|
256 |
+
evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
|
257 |
+
save_strategy="steps",
|
258 |
+
eval_steps=args.eval_steps if VAL_SET_SIZE > 0 else None,
|
259 |
+
save_steps=args.save_steps,
|
260 |
+
output_dir=OUTPUT_DIR,
|
261 |
+
save_total_limit=30,
|
262 |
+
load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
|
263 |
+
ddp_find_unused_parameters=False if ddp else None,
|
264 |
+
report_to="wandb" if args.wandb else [],
|
265 |
+
ignore_data_skip=args.ignore_data_skip,
|
266 |
+
),
|
267 |
+
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
268 |
+
)
|
269 |
+
model.config.use_cache = False
|
270 |
+
|
271 |
+
old_state_dict = model.state_dict
|
272 |
+
model.state_dict = (
|
273 |
+
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
274 |
+
).__get__(model, type(model))
|
275 |
+
|
276 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
277 |
+
model = torch.compile(model)
|
278 |
+
|
279 |
+
print("\n If there's a warning about missing keys above, please disregard :)")
|
280 |
+
|
281 |
+
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
|
282 |
+
|
283 |
+
model.save_pretrained(OUTPUT_DIR)
|
284 |
+
|
finetune_4bit.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import bitsandbytes as bnb
|
7 |
+
from datasets import load_dataset, Dataset
|
8 |
+
import transformers
|
9 |
+
import argparse
|
10 |
+
import warnings
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
|
13 |
+
assert (
|
14 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
15 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
16 |
+
|
17 |
+
|
18 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
|
19 |
+
from peft import (
|
20 |
+
prepare_model_for_kbit_training,
|
21 |
+
LoraConfig,
|
22 |
+
get_peft_model,
|
23 |
+
get_peft_model_state_dict,
|
24 |
+
set_peft_model_state_dict,
|
25 |
+
)
|
26 |
+
|
27 |
+
def generate_prompt(data_point):
|
28 |
+
# sorry about the formatting disaster gotta move fast
|
29 |
+
if data_point["input"]:
|
30 |
+
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
31 |
+
|
32 |
+
### Instruction:
|
33 |
+
{data_point["instruction"]}
|
34 |
+
|
35 |
+
### Input:
|
36 |
+
{data_point["input"]}
|
37 |
+
|
38 |
+
### Response:
|
39 |
+
{data_point["output"]}"""
|
40 |
+
else:
|
41 |
+
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
42 |
+
|
43 |
+
### Instruction:
|
44 |
+
{data_point["instruction"]}
|
45 |
+
|
46 |
+
### Response:
|
47 |
+
{data_point["output"]}"""
|
48 |
+
|
49 |
+
def tokenize(prompt):
|
50 |
+
# there's probably a way to do this with the tokenizer settings
|
51 |
+
# but again, gotta move fast
|
52 |
+
result = tokenizer(
|
53 |
+
prompt,
|
54 |
+
truncation=True,
|
55 |
+
max_length=CUTOFF_LEN + 1,
|
56 |
+
padding="max_length",
|
57 |
+
)
|
58 |
+
return {
|
59 |
+
"input_ids": result["input_ids"][:-1],
|
60 |
+
"attention_mask": result["attention_mask"][:-1],
|
61 |
+
}
|
62 |
+
|
63 |
+
|
64 |
+
def generate_and_tokenize_prompt(data_point):
|
65 |
+
# This function masks out the labels for the input,
|
66 |
+
# so that our loss is computed only on the response.
|
67 |
+
user_prompt = (
|
68 |
+
(
|
69 |
+
f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
70 |
+
|
71 |
+
### Instruction:
|
72 |
+
{data_point["instruction"]}
|
73 |
+
|
74 |
+
### Input:
|
75 |
+
{data_point["input"]}
|
76 |
+
|
77 |
+
### Response:
|
78 |
+
"""
|
79 |
+
)
|
80 |
+
if data_point["input"]
|
81 |
+
else (
|
82 |
+
f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
83 |
+
|
84 |
+
### Instruction:
|
85 |
+
{data_point["instruction"]}
|
86 |
+
|
87 |
+
### Response:
|
88 |
+
"""
|
89 |
+
)
|
90 |
+
)
|
91 |
+
len_user_prompt_tokens = (
|
92 |
+
len(
|
93 |
+
tokenizer(
|
94 |
+
user_prompt,
|
95 |
+
truncation=True,
|
96 |
+
max_length=CUTOFF_LEN + 1,
|
97 |
+
)["input_ids"]
|
98 |
+
)
|
99 |
+
- 1
|
100 |
+
) # no eos token
|
101 |
+
full_tokens = tokenizer(
|
102 |
+
user_prompt + data_point["output"],
|
103 |
+
truncation=True,
|
104 |
+
max_length=CUTOFF_LEN + 1,
|
105 |
+
padding="max_length",
|
106 |
+
)["input_ids"][:-1]
|
107 |
+
return {
|
108 |
+
"input_ids": full_tokens,
|
109 |
+
"labels": [-100] * len_user_prompt_tokens
|
110 |
+
+ full_tokens[len_user_prompt_tokens:],
|
111 |
+
"attention_mask": [1] * (len(full_tokens)),
|
112 |
+
}
|
113 |
+
|
114 |
+
parser = argparse.ArgumentParser()
|
115 |
+
parser.add_argument("--wandb", action="store_true", default=False)
|
116 |
+
parser.add_argument("--data_path", type=str, default="merge.json")
|
117 |
+
parser.add_argument("--output_path", type=str, default="lora-Vicuna")
|
118 |
+
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
|
119 |
+
parser.add_argument("--eval_steps", type=int, default=200)
|
120 |
+
parser.add_argument("--save_steps", type=int, default=200)
|
121 |
+
parser.add_argument("--test_size", type=int, default=200)
|
122 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
123 |
+
parser.add_argument("--lora_remote_checkpoint", type=str, default=None)
|
124 |
+
parser.add_argument("--ignore_data_skip", type=str, default="False")
|
125 |
+
args = parser.parse_args()
|
126 |
+
|
127 |
+
if not args.wandb:
|
128 |
+
os.environ["WANDB_MODE"] = "disable"
|
129 |
+
# optimized for RTX 4090. for larger GPUs, increase some of these?
|
130 |
+
MICRO_BATCH_SIZE = 8 # this could actually be 5 but i like powers of 2
|
131 |
+
BATCH_SIZE = 128
|
132 |
+
MAX_STEPS = None
|
133 |
+
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
134 |
+
EPOCHS = 3 # we don't always need 3 tbh
|
135 |
+
LEARNING_RATE = 3e-4 # the Karpathy constant
|
136 |
+
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
|
137 |
+
LORA_R = 8
|
138 |
+
LORA_ALPHA = 16
|
139 |
+
LORA_DROPOUT = 0.05
|
140 |
+
VAL_SET_SIZE = args.test_size #2000
|
141 |
+
TARGET_MODULES = [
|
142 |
+
"q_proj",
|
143 |
+
"v_proj",
|
144 |
+
]
|
145 |
+
DATA_PATH = args.data_path
|
146 |
+
OUTPUT_DIR = args.output_path #"lora-Vicuna"
|
147 |
+
|
148 |
+
device_map = {"": 0} #"auto"
|
149 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
150 |
+
ddp = world_size != 1
|
151 |
+
if ddp:
|
152 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
153 |
+
GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
|
154 |
+
print(args.model_path)
|
155 |
+
model = LlamaForCausalLM.from_pretrained(
|
156 |
+
args.model_path,
|
157 |
+
load_in_4bit=True,
|
158 |
+
device_map=device_map,
|
159 |
+
)
|
160 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
161 |
+
args.model_path, add_eos_token=True
|
162 |
+
)
|
163 |
+
|
164 |
+
model.gradient_checkpointing_enable()
|
165 |
+
model = prepare_model_for_kbit_training(model)
|
166 |
+
|
167 |
+
config = LoraConfig(
|
168 |
+
r=LORA_R,
|
169 |
+
lora_alpha=LORA_ALPHA,
|
170 |
+
target_modules=TARGET_MODULES,
|
171 |
+
lora_dropout=LORA_DROPOUT,
|
172 |
+
bias="none",
|
173 |
+
task_type="CAUSAL_LM",
|
174 |
+
)
|
175 |
+
model = get_peft_model(model, config)
|
176 |
+
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
|
177 |
+
#tokenizer.padding_side = "left" # Allow batched inference
|
178 |
+
|
179 |
+
data = load_dataset("json", data_files=DATA_PATH)
|
180 |
+
import random;start = random.randint(1, 100)
|
181 |
+
examples = Dataset.from_dict(data['train'][start:start+5]).map(generate_and_tokenize_prompt)
|
182 |
+
for example in examples:
|
183 |
+
print(f'>>> prompt example:\n { tokenizer.decode(example["input_ids"]) }')
|
184 |
+
print(f'>>> tokenizer labels: { tokenizer.decode([ 0 if l==-100 else l for l in example["labels"]])}')
|
185 |
+
print(f'>>> tokenizer example: { example["input_ids"][:250] }...{ example["input_ids"][-10:]}')
|
186 |
+
|
187 |
+
now_max_steps = max((len(data["train"]) - VAL_SET_SIZE) // BATCH_SIZE * EPOCHS, EPOCHS)
|
188 |
+
if args.resume_from_checkpoint:
|
189 |
+
if args.lora_remote_checkpoint is not None:
|
190 |
+
snapshot_download(repo_id=args.lora_remote_checkpoint, allow_patterns=["*.pt", "*.bin", "*.json"], local_dir=args.resume_from_checkpoint)
|
191 |
+
# Check the available weights and load them
|
192 |
+
checkpoint_name = os.path.join(
|
193 |
+
args.resume_from_checkpoint, "pytorch_model.bin"
|
194 |
+
) # Full checkpoint
|
195 |
+
if not os.path.exists(checkpoint_name):
|
196 |
+
pytorch_bin_path = checkpoint_name
|
197 |
+
checkpoint_name = os.path.join(
|
198 |
+
args.resume_from_checkpoint, "adapter_model.bin"
|
199 |
+
) # only LoRA model - LoRA config above has to fit
|
200 |
+
if os.path.exists(checkpoint_name):
|
201 |
+
os.rename(checkpoint_name, pytorch_bin_path)
|
202 |
+
warnings.warn("The file name of the lora checkpoint'adapter_model.bin' is replaced with 'pytorch_model.bin'")
|
203 |
+
else:
|
204 |
+
args.resume_from_checkpoint = (
|
205 |
+
None # So the trainer won't try loading its state
|
206 |
+
)
|
207 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
208 |
+
if os.path.exists(checkpoint_name):
|
209 |
+
print(f"Restarting from {checkpoint_name}")
|
210 |
+
adapters_weights = torch.load(checkpoint_name)
|
211 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
212 |
+
else:
|
213 |
+
print(f"Checkpoint {checkpoint_name} not found")
|
214 |
+
|
215 |
+
train_args_path = os.path.join(args.resume_from_checkpoint, "trainer_state.json")
|
216 |
+
|
217 |
+
if os.path.exists(train_args_path):
|
218 |
+
import json
|
219 |
+
base_train_args = json.load(open(train_args_path, 'r'))
|
220 |
+
base_max_steps = base_train_args["max_steps"]
|
221 |
+
resume_scale = base_max_steps / now_max_steps
|
222 |
+
if base_max_steps > now_max_steps:
|
223 |
+
warnings.warn("epoch {} replace to the base_max_steps {}".format(EPOCHS, base_max_steps))
|
224 |
+
EPOCHS = None
|
225 |
+
MAX_STEPS = base_max_steps
|
226 |
+
else:
|
227 |
+
MAX_STEPS = now_max_steps
|
228 |
+
else:
|
229 |
+
MAX_STEPS = now_max_steps
|
230 |
+
|
231 |
+
|
232 |
+
model.print_trainable_parameters()
|
233 |
+
|
234 |
+
|
235 |
+
num_proc = (os.cpu_count())
|
236 |
+
if VAL_SET_SIZE > 0:
|
237 |
+
train_val = data["train"].train_test_split(
|
238 |
+
test_size=VAL_SET_SIZE, shuffle=True, seed=42
|
239 |
+
)
|
240 |
+
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt, num_proc=num_proc)
|
241 |
+
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt, num_proc=num_proc)
|
242 |
+
else:
|
243 |
+
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt, num_proc=num_proc)
|
244 |
+
val_data = None
|
245 |
+
|
246 |
+
trainer = transformers.Trainer(
|
247 |
+
model=model,
|
248 |
+
train_dataset=train_data,
|
249 |
+
eval_dataset=val_data,
|
250 |
+
args=transformers.TrainingArguments(
|
251 |
+
per_device_train_batch_size=MICRO_BATCH_SIZE,
|
252 |
+
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
|
253 |
+
warmup_steps=100,
|
254 |
+
num_train_epochs=EPOCHS,
|
255 |
+
max_steps=MAX_STEPS,
|
256 |
+
learning_rate=LEARNING_RATE,
|
257 |
+
fp16=True,
|
258 |
+
logging_steps=20,
|
259 |
+
evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
|
260 |
+
save_strategy="steps",
|
261 |
+
eval_steps=args.eval_steps if VAL_SET_SIZE > 0 else None,
|
262 |
+
save_steps=args.save_steps,
|
263 |
+
output_dir=OUTPUT_DIR,
|
264 |
+
save_total_limit=30,
|
265 |
+
load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
|
266 |
+
ddp_find_unused_parameters=False if ddp else None,
|
267 |
+
report_to="wandb" if args.wandb else [],
|
268 |
+
ignore_data_skip=args.ignore_data_skip,
|
269 |
+
optim="paged_adamw_8bit",
|
270 |
+
),
|
271 |
+
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
272 |
+
)
|
273 |
+
model.config.use_cache = False
|
274 |
+
|
275 |
+
old_state_dict = model.state_dict
|
276 |
+
model.state_dict = (
|
277 |
+
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
278 |
+
).__get__(model, type(model))
|
279 |
+
|
280 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
281 |
+
model = torch.compile(model)
|
282 |
+
|
283 |
+
print("\n If there's a warning about missing keys above, please disregard :)")
|
284 |
+
|
285 |
+
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
|
286 |
+
|
287 |
+
model.save_pretrained(OUTPUT_DIR)
|
288 |
+
|
finetune_chat.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from peft import (
|
2 |
+
prepare_model_for_int8_training,
|
3 |
+
LoraConfig,
|
4 |
+
PeftModel,
|
5 |
+
get_peft_model,
|
6 |
+
get_peft_model_state_dict,
|
7 |
+
set_peft_model_state_dict,
|
8 |
+
)
|
9 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, TrainerCallback, GenerationConfig
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import bitsandbytes as bnb
|
15 |
+
from datasets import load_dataset, Dataset
|
16 |
+
import transformers
|
17 |
+
from huggingface_hub import snapshot_download
|
18 |
+
import argparse
|
19 |
+
import warnings
|
20 |
+
from tqdm import tqdm
|
21 |
+
from functools import partial
|
22 |
+
import utils
|
23 |
+
import prompt
|
24 |
+
assert (
|
25 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
26 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
27 |
+
|
28 |
+
# 0. prepare args and logger
|
29 |
+
parser = argparse.ArgumentParser()
|
30 |
+
parser.add_argument("--wandb", action="store_true", default=False)
|
31 |
+
parser.add_argument("--prompt_type", type=str, default="chat")
|
32 |
+
parser.add_argument("--data_path", type=str, default="merge.json")
|
33 |
+
parser.add_argument("--output_path", type=str, default="lora-Vicuna")
|
34 |
+
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
|
35 |
+
parser.add_argument("--num_epoch", type=int, default=3)
|
36 |
+
parser.add_argument("--micro_batch", type=int, default=4)
|
37 |
+
parser.add_argument("--total_batch", type=int, default=128)
|
38 |
+
parser.add_argument("--log_steps", type=int, default=100)
|
39 |
+
parser.add_argument("--eval_steps", type=int, default=200)
|
40 |
+
parser.add_argument("--save_steps", type=int, default=200)
|
41 |
+
parser.add_argument("--warmup_ratio", type=float, default=0.05)
|
42 |
+
parser.add_argument("--test_size", type=int, default=200)
|
43 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
44 |
+
parser.add_argument("--lora_remote_checkpoint", type=str, default=None)
|
45 |
+
parser.add_argument("--ignore_data_skip", type=bool, default=False)
|
46 |
+
args = parser.parse_args()
|
47 |
+
if not args.wandb:
|
48 |
+
os.environ["WANDB_MODE"] = "disable"
|
49 |
+
MICRO_BATCH_SIZE = args.micro_batch # this could actually be 5 but i like powers of 2
|
50 |
+
BATCH_SIZE = args.total_batch
|
51 |
+
MAX_STEPS = None
|
52 |
+
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
53 |
+
EPOCHS = args.num_epoch
|
54 |
+
LEARNING_RATE = 3e-4 # the Karpathy constant
|
55 |
+
CUTOFF_LEN = 2048
|
56 |
+
LORA_R = 8
|
57 |
+
LORA_ALPHA = 16
|
58 |
+
LORA_DROPOUT = 0.05
|
59 |
+
USE_8bit = True
|
60 |
+
VAL_SET_SIZE = args.test_size # 2000
|
61 |
+
TARGET_MODULES = [
|
62 |
+
"q_proj",
|
63 |
+
"v_proj",
|
64 |
+
"k_proj",
|
65 |
+
"o_proj",
|
66 |
+
"down_proj",
|
67 |
+
"gate_proj",
|
68 |
+
"up_proj",
|
69 |
+
]
|
70 |
+
DATA_PATH = args.data_path
|
71 |
+
OUTPUT_DIR = args.output_path # "lora-Vicuna"
|
72 |
+
|
73 |
+
device_map = "auto"
|
74 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
75 |
+
ddp = world_size != 1
|
76 |
+
if ddp:
|
77 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
78 |
+
GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
|
79 |
+
# we must make sure batch_size and gradient_accumulation_steps not changed for resuming training.
|
80 |
+
if args.resume_from_checkpoint:
|
81 |
+
old_args_path = os.path.join(args.resume_from_checkpoint, 'training_args.bin')
|
82 |
+
if os.path.exists(old_args_path):
|
83 |
+
old_args = torch.load(old_args_path)
|
84 |
+
if MICRO_BATCH_SIZE != old_args.per_device_train_batch_size:
|
85 |
+
raise Exception(
|
86 |
+
f'current micro batch size {MICRO_BATCH_SIZE} is not equal to the old {old_args.per_device_train_batch_size},'
|
87 |
+
' This will cause the trainer skips wrong epochs or steps.'
|
88 |
+
f'please change your micro batch size to {old_args.per_device_train_batch_size}'
|
89 |
+
' or cancel resuming your training'
|
90 |
+
)
|
91 |
+
if GRADIENT_ACCUMULATION_STEPS != old_args.gradient_accumulation_steps:
|
92 |
+
raise Exception(
|
93 |
+
f'current total batch {BATCH_SIZE} is not equal to the old {old_args.gradient_accumulation_steps*old_args.per_device_train_batch_size},'
|
94 |
+
' This will cause the trainer skips wrong epochs or steps.'
|
95 |
+
f'please change your total batch size to {old_args.gradient_accumulation_steps*old_args.per_device_train_batch_size}'
|
96 |
+
' or cancel resuming your training'
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
raise Exception(f'{old_args_path} is not exist!')
|
100 |
+
# checkpoint = os.path.join(args.resume_from_checkpoint, 'pytorch_model.bin')
|
101 |
+
|
102 |
+
logger = utils.set_file_logger(__name__,OUTPUT_DIR)
|
103 |
+
# 1. load dataset
|
104 |
+
logger.info(f'>>> processing data from {DATA_PATH}')
|
105 |
+
logger.info(f'>>> using {args}')
|
106 |
+
|
107 |
+
train_tokenizer = LlamaTokenizer.from_pretrained(args.model_path, add_eos_token=True)
|
108 |
+
assert train_tokenizer.eos_token_id == 2, "Tokenizer eos is wrong!!!"
|
109 |
+
# unk. we want this to be different from the eos token
|
110 |
+
train_tokenizer.pad_token_id = 0
|
111 |
+
# cannot use eos in generation!
|
112 |
+
# tokenizer.padding_side = "left" # Allow batched inference
|
113 |
+
test_tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
|
114 |
+
if args.prompt_type == 'instruct':
|
115 |
+
PROMPT = prompt.instruct_prompt(train_tokenizer, CUTOFF_LEN)
|
116 |
+
elif args.prompt_type == 'chat':
|
117 |
+
PROMPT = prompt.chat_prompt(train_tokenizer,CUTOFF_LEN)
|
118 |
+
else:
|
119 |
+
raise Exception('not support')
|
120 |
+
# check tokenizer
|
121 |
+
data = load_dataset('json', data_files=DATA_PATH)
|
122 |
+
import random;start = random.randint(1, 100)
|
123 |
+
examples = Dataset.from_dict(data['train'][start:start+5]).map(PROMPT.preprocess_train)
|
124 |
+
for example in examples:
|
125 |
+
logger.info(f'>>> using prompt {args.prompt_type}, prompt example:\n { train_tokenizer.decode(example["input_ids"]) }')
|
126 |
+
logger.info(f'>>> tokenizer labels: { train_tokenizer.decode([ 0 if l==-100 else l for l in example["labels"]])}')
|
127 |
+
logger.info(f'>>> tokenizer example: { example["input_ids"][:10] }...{ example["input_ids"][-10:]}')
|
128 |
+
# 2. load model and checkpoints
|
129 |
+
logger.info(f'>>> load model from {args.model_path}')
|
130 |
+
|
131 |
+
if USE_8bit is True:
|
132 |
+
assert bnb.__version__ >= '0.37.2', "Please downgrade bitsandbytes's version, for example: pip install bitsandbytes==0.37.2"
|
133 |
+
model = LlamaForCausalLM.from_pretrained(
|
134 |
+
args.model_path,
|
135 |
+
load_in_8bit=USE_8bit,
|
136 |
+
device_map=device_map,
|
137 |
+
torch_dtype=torch.float16,
|
138 |
+
)
|
139 |
+
if USE_8bit is True:
|
140 |
+
model = prepare_model_for_int8_training(model)
|
141 |
+
config = LoraConfig(
|
142 |
+
r=LORA_R,
|
143 |
+
lora_alpha=LORA_ALPHA,
|
144 |
+
target_modules=TARGET_MODULES,
|
145 |
+
lora_dropout=LORA_DROPOUT,
|
146 |
+
bias="none",
|
147 |
+
task_type="CAUSAL_LM",
|
148 |
+
)
|
149 |
+
model = get_peft_model(model, config)
|
150 |
+
if args.resume_from_checkpoint:
|
151 |
+
checkpoint_name = os.path.join(args.resume_from_checkpoint, "pytorch_model.bin")
|
152 |
+
# adapter_model.bin
|
153 |
+
if not os.path.exists(checkpoint_name):
|
154 |
+
pytorch_bin_path = checkpoint_name
|
155 |
+
checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin")
|
156 |
+
if os.path.exists(checkpoint_name):
|
157 |
+
os.rename(checkpoint_name, pytorch_bin_path)
|
158 |
+
logger.warning("The file name of the lora checkpoint'adapter_model.bin' is replaced with 'pytorch_model.bin'")
|
159 |
+
else:
|
160 |
+
args.resume_from_checkpoint = None # So the trainer won't try loading its state
|
161 |
+
# pytorch_model.bin
|
162 |
+
if os.path.exists(checkpoint_name):
|
163 |
+
logger.info(f'>>> load lora from {checkpoint_name}')
|
164 |
+
adapters_weights = torch.load(checkpoint_name)
|
165 |
+
set_peft_model_state_dict(model, adapters_weights)
|
166 |
+
else:
|
167 |
+
raise Exception(f"Checkpoint {checkpoint_name} not found with resume_from_checkpoint=True!")
|
168 |
+
|
169 |
+
trainable_params = 0
|
170 |
+
all_param = 0
|
171 |
+
for _, param in model.named_parameters():
|
172 |
+
num_params = param.numel()
|
173 |
+
# if using DS Zero 3 and the weights are initialized empty
|
174 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
175 |
+
num_params = param.ds_numel
|
176 |
+
all_param += num_params
|
177 |
+
if param.requires_grad:
|
178 |
+
trainable_params += num_params
|
179 |
+
logger.info(f">>> trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")
|
180 |
+
|
181 |
+
# 3. speedup dataset processing by multi-process
|
182 |
+
num_proc = (os.cpu_count())
|
183 |
+
if VAL_SET_SIZE > 0:
|
184 |
+
train_val = data["train"].train_test_split(test_size=VAL_SET_SIZE, shuffle=True, seed=42)
|
185 |
+
train_data = train_val["train"].shuffle().map(PROMPT.preprocess_train, num_proc=num_proc)
|
186 |
+
val_data = train_val["test"].shuffle().map(PROMPT.preprocess_train, num_proc=num_proc)
|
187 |
+
else:
|
188 |
+
train_data = data["train"].shuffle().map(PROMPT.preprocess_train, num_proc=num_proc)
|
189 |
+
val_data = None
|
190 |
+
now_max_steps = max((len(data["train"]) - VAL_SET_SIZE) // BATCH_SIZE * EPOCHS, EPOCHS)
|
191 |
+
if args.resume_from_checkpoint:
|
192 |
+
# the trainer will ignore the state max_steps and caculate max_steps based on epochs,
|
193 |
+
# so we mannally set the args.max_step to override it.
|
194 |
+
if args.lora_remote_checkpoint is not None:
|
195 |
+
snapshot_download(repo_id=args.lora_remote_checkpoint, allow_patterns=["*.pt", "*.bin", "*.json"], local_dir=args.resume_from_checkpoint)
|
196 |
+
train_state_path = os.path.join(args.resume_from_checkpoint, "trainer_state.json")
|
197 |
+
if os.path.exists(train_state_path):
|
198 |
+
import json
|
199 |
+
base_train_args = json.load(open(train_state_path, 'r'))
|
200 |
+
base_max_steps = base_train_args["max_steps"]
|
201 |
+
resume_scale = base_max_steps / now_max_steps
|
202 |
+
if base_max_steps > now_max_steps:
|
203 |
+
logger.warning(f"epoch {EPOCHS}:{MAX_STEPS} replace to the base_max_steps {base_max_steps}")
|
204 |
+
EPOCHS = None
|
205 |
+
MAX_STEPS = base_max_steps
|
206 |
+
else:
|
207 |
+
MAX_STEPS = now_max_steps
|
208 |
+
assert MAX_STEPS is not None
|
209 |
+
else:
|
210 |
+
MAX_STEPS = now_max_steps
|
211 |
+
|
212 |
+
# 4. start training
|
213 |
+
class CustomCallback(TrainerCallback):
|
214 |
+
|
215 |
+
def __init__(self, trainer) -> None:
|
216 |
+
super().__init__()
|
217 |
+
self.trainer = trainer
|
218 |
+
self.generation_config = GenerationConfig(
|
219 |
+
temperature=1.0,
|
220 |
+
top_p=0.75,
|
221 |
+
top_k=40,
|
222 |
+
num_beams=2,
|
223 |
+
bos_token_id=train_tokenizer.bos_token_id,
|
224 |
+
eos_token_id=train_tokenizer.eos_token_id,
|
225 |
+
pad_token_id=train_tokenizer.pad_token_id,
|
226 |
+
max_new_tokens=1024, # max_length=max_new_tokens+input_sequence
|
227 |
+
min_new_tokens=1, # min_length=min_new_tokens+input_sequence
|
228 |
+
bad_words_ids=test_tokenizer(['\n\nUser:','\n\nAssistant:'], add_special_tokens=False).input_ids
|
229 |
+
)
|
230 |
+
self.repetition_penalty=1.3
|
231 |
+
self.logger = utils.set_file_logger('transformers.trainer', trainer.args.output_dir)
|
232 |
+
|
233 |
+
def on_log(self, args, state, control, logs, **kwargs):
|
234 |
+
logger.info(logs)
|
235 |
+
|
236 |
+
trainer = transformers.Trainer(
|
237 |
+
model=model,
|
238 |
+
train_dataset=train_data,
|
239 |
+
eval_dataset=val_data,
|
240 |
+
args=transformers.TrainingArguments(
|
241 |
+
per_device_train_batch_size=MICRO_BATCH_SIZE,
|
242 |
+
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
|
243 |
+
warmup_ratio=args.warmup_ratio,
|
244 |
+
num_train_epochs=EPOCHS,
|
245 |
+
max_steps=MAX_STEPS,
|
246 |
+
learning_rate=LEARNING_RATE,
|
247 |
+
fp16=True,
|
248 |
+
logging_steps=args.log_steps,
|
249 |
+
logging_first_step=True, # convenient
|
250 |
+
evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
|
251 |
+
save_strategy="steps",
|
252 |
+
eval_steps=args.eval_steps if VAL_SET_SIZE > 0 else None,
|
253 |
+
save_steps=args.save_steps,
|
254 |
+
output_dir=OUTPUT_DIR,
|
255 |
+
load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
|
256 |
+
ddp_find_unused_parameters=False if ddp else None,
|
257 |
+
report_to="wandb" if args.wandb else [],
|
258 |
+
ignore_data_skip=args.ignore_data_skip,
|
259 |
+
),
|
260 |
+
data_collator=PROMPT.data_collator()
|
261 |
+
)
|
262 |
+
trainer.add_callback(CustomCallback(trainer))
|
263 |
+
model.config.use_cache = False
|
264 |
+
|
265 |
+
old_state_dict = model.state_dict
|
266 |
+
model.state_dict = (
|
267 |
+
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
268 |
+
).__get__(model, type(model))
|
269 |
+
|
270 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
271 |
+
model = torch.compile(model)
|
272 |
+
|
273 |
+
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
|
274 |
+
model.save_pretrained(OUTPUT_DIR)
|
finetune_fp16.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import bitsandbytes as bnb
|
7 |
+
from datasets import load_dataset
|
8 |
+
import transformers
|
9 |
+
import argparse
|
10 |
+
import warnings
|
11 |
+
from huggingface_hub import snapshot_download
|
12 |
+
|
13 |
+
assert (
|
14 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
15 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
16 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
17 |
+
from peft import (
|
18 |
+
prepare_model_for_int8_training,
|
19 |
+
LoraConfig,
|
20 |
+
get_peft_model,
|
21 |
+
get_peft_model_state_dict,
|
22 |
+
set_peft_model_state_dict,
|
23 |
+
)
|
24 |
+
|
25 |
+
def get_peft_state_maybe_zero_3(state_dict, bias):
|
26 |
+
if hasattr(param, "ds_id"):
|
27 |
+
assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
|
28 |
+
with zero.GatheredParameters([param]):
|
29 |
+
param = param.data.cpu().clone().detach()
|
30 |
+
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
|
31 |
+
return to_return
|
32 |
+
|
33 |
+
parser = argparse.ArgumentParser()
|
34 |
+
parser.add_argument("--wandb", action="store_true", default=False)
|
35 |
+
parser.add_argument("--data_path", type=str, default="merge.json")
|
36 |
+
parser.add_argument("--output_path", type=str, default="lora-Vicuna")
|
37 |
+
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
|
38 |
+
parser.add_argument("--eval_steps", type=int, default=200)
|
39 |
+
parser.add_argument("--save_steps", type=int, default=200)
|
40 |
+
parser.add_argument("--test_size", type=int, default=200)
|
41 |
+
parser.add_argument("--resume_from_checkpoint", type=str, default=None)
|
42 |
+
parser.add_argument("--ignore_data_skip", type=str, default="False")
|
43 |
+
parser.add_argument("--lora_remote_checkpoint", type=str, default=None)
|
44 |
+
|
45 |
+
parser.add_argument("--local_rank", type=int, default=-1)
|
46 |
+
parser.add_argument("--deepspeed", action="store_true", default=False)
|
47 |
+
|
48 |
+
args = parser.parse_args()
|
49 |
+
|
50 |
+
if not args.wandb:
|
51 |
+
os.environ["WANDB_MODE"] = "disable"
|
52 |
+
# optimized for RTX 4090. for larger GPUs, increase some of these?
|
53 |
+
MICRO_BATCH_SIZE = 2 # this could actually be 5 but i like powers of 2
|
54 |
+
BATCH_SIZE = 128
|
55 |
+
MAX_STEPS = None
|
56 |
+
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
57 |
+
EPOCHS = 3 # we don't always need 3 tbh
|
58 |
+
LEARNING_RATE = 3e-4 # the Karpathy constant
|
59 |
+
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
|
60 |
+
LORA_R = 8
|
61 |
+
LORA_ALPHA = 16
|
62 |
+
LORA_DROPOUT = 0.05
|
63 |
+
VAL_SET_SIZE = args.test_size #2000
|
64 |
+
TARGET_MODULES = [
|
65 |
+
"q_proj",
|
66 |
+
"v_proj",
|
67 |
+
]
|
68 |
+
DATA_PATH = args.data_path
|
69 |
+
OUTPUT_DIR = args.output_path #"lora-Vicuna"
|
70 |
+
|
71 |
+
device_map = {"": 0} #"auto"
|
72 |
+
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
73 |
+
ddp = world_size != 1
|
74 |
+
if ddp:
|
75 |
+
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
|
76 |
+
GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
|
77 |
+
print(args.model_path)
|
78 |
+
model = LlamaForCausalLM.from_pretrained(
|
79 |
+
args.model_path,
|
80 |
+
load_in_8bit=False,
|
81 |
+
torch_dtype=torch.float16,
|
82 |
+
device_map=device_map,
|
83 |
+
).half()
|
84 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
85 |
+
args.model_path, add_eos_token=True
|
86 |
+
)
|
87 |
+
|
88 |
+
#model = prepare_model_for_int8_training(model)
|
89 |
+
|
90 |
+
config = LoraConfig(
|
91 |
+
r=LORA_R,
|
92 |
+
lora_alpha=LORA_ALPHA,
|
93 |
+
target_modules=TARGET_MODULES,
|
94 |
+
lora_dropout=LORA_DROPOUT,
|
95 |
+
bias="none",
|
96 |
+
task_type="CAUSAL_LM",
|
97 |
+
)
|
98 |
+
model = get_peft_model(model, config)
|
99 |
+
|
100 |
+
tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
|
101 |
+
#tokenizer.padding_side = "left" # Allow batched inference
|
102 |
+
|
103 |
+
data = load_dataset("json", data_files=DATA_PATH)
|
104 |
+
|
105 |
+
now_max_steps = max((len(data["train"]) - VAL_SET_SIZE) // BATCH_SIZE * EPOCHS, EPOCHS)
|
106 |
+
if args.resume_from_checkpoint:
|
107 |
+
if args.lora_remote_checkpoint is not None:
|
108 |
+
snapshot_download(repo_id=args.lora_remote_checkpoint, allow_patterns=["*.pt", "*.bin", "*.json"], local_dir=args.resume_from_checkpoint)
|
109 |
+
# Check the available weights and load them
|
110 |
+
checkpoint_name = os.path.join(
|
111 |
+
args.resume_from_checkpoint, "pytorch_model.bin"
|
112 |
+
) # Full checkpoint
|
113 |
+
if not os.path.exists(checkpoint_name):
|
114 |
+
pytorch_bin_path = checkpoint_name
|
115 |
+
checkpoint_name = os.path.join(
|
116 |
+
args.resume_from_checkpoint, "adapter_model.bin"
|
117 |
+
) # only LoRA model - LoRA config above has to fit
|
118 |
+
if os.path.exists(checkpoint_name):
|
119 |
+
os.rename(checkpoint_name, pytorch_bin_path)
|
120 |
+
warnings.warn("The file name of the lora checkpoint'adapter_model.bin' is replaced with 'pytorch_model.bin'")
|
121 |
+
else:
|
122 |
+
args.resume_from_checkpoint = (
|
123 |
+
None # So the trainer won't try loading its state
|
124 |
+
)
|
125 |
+
# The two files above have a different name depending on how they were saved, but are actually the same.
|
126 |
+
if os.path.exists(checkpoint_name):
|
127 |
+
print(f"Restarting from {checkpoint_name}")
|
128 |
+
adapters_weights = torch.load(checkpoint_name)
|
129 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
130 |
+
else:
|
131 |
+
print(f"Checkpoint {checkpoint_name} not found")
|
132 |
+
|
133 |
+
train_args_path = os.path.join(args.resume_from_checkpoint, "trainer_state.json")
|
134 |
+
|
135 |
+
if os.path.exists(train_args_path):
|
136 |
+
import json
|
137 |
+
base_train_args = json.load(open(train_args_path, 'r'))
|
138 |
+
base_max_steps = base_train_args["max_steps"]
|
139 |
+
resume_scale = base_max_steps / now_max_steps
|
140 |
+
if base_max_steps > now_max_steps:
|
141 |
+
warnings.warn("epoch {} replace to the base_max_steps {}".format(EPOCHS, base_max_steps))
|
142 |
+
EPOCHS = None
|
143 |
+
MAX_STEPS = base_max_steps
|
144 |
+
else:
|
145 |
+
MAX_STEPS = now_max_steps
|
146 |
+
else:
|
147 |
+
MAX_STEPS = now_max_steps
|
148 |
+
|
149 |
+
|
150 |
+
model.print_trainable_parameters()
|
151 |
+
|
152 |
+
def generate_prompt(data_point):
|
153 |
+
# sorry about the formatting disaster gotta move fast
|
154 |
+
if data_point["input"]:
|
155 |
+
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
156 |
+
|
157 |
+
### Instruction:
|
158 |
+
{data_point["instruction"]}
|
159 |
+
|
160 |
+
### Input:
|
161 |
+
{data_point["input"]}
|
162 |
+
|
163 |
+
### Response:
|
164 |
+
{data_point["output"]}"""
|
165 |
+
else:
|
166 |
+
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
167 |
+
|
168 |
+
### Instruction:
|
169 |
+
{data_point["instruction"]}
|
170 |
+
|
171 |
+
### Response:
|
172 |
+
{data_point["output"]}"""
|
173 |
+
|
174 |
+
|
175 |
+
def tokenize(prompt):
|
176 |
+
# there's probably a way to do this with the tokenizer settings
|
177 |
+
# but again, gotta move fast
|
178 |
+
result = tokenizer(
|
179 |
+
prompt,
|
180 |
+
truncation=True,
|
181 |
+
max_length=CUTOFF_LEN + 1,
|
182 |
+
padding="max_length",
|
183 |
+
)
|
184 |
+
return {
|
185 |
+
"input_ids": result["input_ids"][:-1],
|
186 |
+
"attention_mask": result["attention_mask"][:-1],
|
187 |
+
}
|
188 |
+
|
189 |
+
|
190 |
+
def generate_and_tokenize_prompt(data_point):
|
191 |
+
# This function masks out the labels for the input,
|
192 |
+
# so that our loss is computed only on the response.
|
193 |
+
user_prompt = (
|
194 |
+
(
|
195 |
+
f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
196 |
+
|
197 |
+
### Instruction:
|
198 |
+
{data_point["instruction"]}
|
199 |
+
|
200 |
+
### Input:
|
201 |
+
{data_point["input"]}
|
202 |
+
|
203 |
+
### Response:
|
204 |
+
"""
|
205 |
+
)
|
206 |
+
if data_point["input"]
|
207 |
+
else (
|
208 |
+
f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
209 |
+
|
210 |
+
### Instruction:
|
211 |
+
{data_point["instruction"]}
|
212 |
+
|
213 |
+
### Response:
|
214 |
+
"""
|
215 |
+
)
|
216 |
+
)
|
217 |
+
len_user_prompt_tokens = (
|
218 |
+
len(
|
219 |
+
tokenizer(
|
220 |
+
user_prompt,
|
221 |
+
truncation=True,
|
222 |
+
max_length=CUTOFF_LEN + 1,
|
223 |
+
)["input_ids"]
|
224 |
+
)
|
225 |
+
- 1
|
226 |
+
) # no eos token
|
227 |
+
full_tokens = tokenizer(
|
228 |
+
user_prompt + data_point["output"],
|
229 |
+
truncation=True,
|
230 |
+
max_length=CUTOFF_LEN + 1,
|
231 |
+
padding="max_length",
|
232 |
+
)["input_ids"][:-1]
|
233 |
+
return {
|
234 |
+
"input_ids": full_tokens,
|
235 |
+
"labels": [-100] * len_user_prompt_tokens
|
236 |
+
+ full_tokens[len_user_prompt_tokens:],
|
237 |
+
"attention_mask": [1] * (len(full_tokens)),
|
238 |
+
}
|
239 |
+
|
240 |
+
|
241 |
+
if VAL_SET_SIZE > 0:
|
242 |
+
train_val = data["train"].train_test_split(
|
243 |
+
test_size=VAL_SET_SIZE, shuffle=True, seed=42
|
244 |
+
)
|
245 |
+
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
|
246 |
+
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
|
247 |
+
else:
|
248 |
+
train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
|
249 |
+
val_data = None
|
250 |
+
trainer = transformers.Trainer(
|
251 |
+
model=model,
|
252 |
+
train_dataset=train_data,
|
253 |
+
eval_dataset=val_data,
|
254 |
+
args=transformers.TrainingArguments(
|
255 |
+
per_device_train_batch_size=MICRO_BATCH_SIZE,
|
256 |
+
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
|
257 |
+
warmup_steps=100,
|
258 |
+
num_train_epochs=EPOCHS,
|
259 |
+
max_steps=MAX_STEPS,
|
260 |
+
learning_rate=LEARNING_RATE,
|
261 |
+
fp16=True,
|
262 |
+
logging_steps=20,
|
263 |
+
evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
|
264 |
+
save_strategy="steps",
|
265 |
+
eval_steps=args.eval_steps if VAL_SET_SIZE > 0 else None,
|
266 |
+
save_steps=args.save_steps,
|
267 |
+
output_dir=OUTPUT_DIR,
|
268 |
+
save_total_limit=30,
|
269 |
+
load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
|
270 |
+
ddp_find_unused_parameters=False if ddp else None,
|
271 |
+
report_to="wandb" if args.wandb else [],
|
272 |
+
ignore_data_skip=args.ignore_data_skip,
|
273 |
+
deepspeed="sample/zero_config.json" if args.deepspeed else None,
|
274 |
+
),
|
275 |
+
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
276 |
+
)
|
277 |
+
model.config.use_cache = False
|
278 |
+
|
279 |
+
old_state_dict = model.state_dict
|
280 |
+
model.state_dict = (
|
281 |
+
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
282 |
+
).__get__(model, type(model))
|
283 |
+
|
284 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
285 |
+
model = torch.compile(model)
|
286 |
+
|
287 |
+
print("\n If there's a warning about missing keys above, please disregard :)")
|
288 |
+
|
289 |
+
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
|
290 |
+
|
291 |
+
model.save_pretrained(OUTPUT_DIR)
|
292 |
+
|
generate.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
|
4 |
+
import transformers
|
5 |
+
import gradio as gr
|
6 |
+
import argparse
|
7 |
+
import warnings
|
8 |
+
import os
|
9 |
+
from utils import StreamPeftGenerationMixin,StreamLlamaForCausalLM
|
10 |
+
assert (
|
11 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
12 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
13 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--model_path", type=str, default="/model/13B_hf")
|
17 |
+
parser.add_argument("--lora_path", type=str, default="checkpoint-3000")
|
18 |
+
parser.add_argument("--use_typewriter", type=int, default=1)
|
19 |
+
parser.add_argument("--use_local", type=int, default=1)
|
20 |
+
args = parser.parse_args()
|
21 |
+
print(args)
|
22 |
+
tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
|
23 |
+
|
24 |
+
LOAD_8BIT = True
|
25 |
+
BASE_MODEL = args.model_path
|
26 |
+
LORA_WEIGHTS = args.lora_path
|
27 |
+
|
28 |
+
|
29 |
+
# fix the path for local checkpoint
|
30 |
+
lora_bin_path = os.path.join(args.lora_path, "adapter_model.bin")
|
31 |
+
print(lora_bin_path)
|
32 |
+
if not os.path.exists(lora_bin_path) and args.use_local:
|
33 |
+
pytorch_bin_path = os.path.join(args.lora_path, "pytorch_model.bin")
|
34 |
+
print(pytorch_bin_path)
|
35 |
+
if os.path.exists(pytorch_bin_path):
|
36 |
+
os.rename(pytorch_bin_path, lora_bin_path)
|
37 |
+
warnings.warn(
|
38 |
+
"The file name of the lora checkpoint'pytorch_model.bin' is replaced with 'adapter_model.bin'"
|
39 |
+
)
|
40 |
+
else:
|
41 |
+
assert ('Checkpoint is not Found!')
|
42 |
+
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
device = "cuda"
|
45 |
+
else:
|
46 |
+
device = "cpu"
|
47 |
+
|
48 |
+
try:
|
49 |
+
if torch.backends.mps.is_available():
|
50 |
+
device = "mps"
|
51 |
+
except:
|
52 |
+
pass
|
53 |
+
|
54 |
+
if device == "cuda":
|
55 |
+
model = LlamaForCausalLM.from_pretrained(
|
56 |
+
BASE_MODEL,
|
57 |
+
load_in_8bit=LOAD_8BIT,
|
58 |
+
torch_dtype=torch.float16,
|
59 |
+
device_map="auto", #device_map={"": 0},
|
60 |
+
)
|
61 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
62 |
+
model, LORA_WEIGHTS, torch_dtype=torch.float16, device_map="auto", #device_map={"": 0}
|
63 |
+
)
|
64 |
+
elif device == "mps":
|
65 |
+
model = LlamaForCausalLM.from_pretrained(
|
66 |
+
BASE_MODEL,
|
67 |
+
device_map={"": device},
|
68 |
+
torch_dtype=torch.float16,
|
69 |
+
)
|
70 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
71 |
+
model,
|
72 |
+
LORA_WEIGHTS,
|
73 |
+
device_map={"": device},
|
74 |
+
torch_dtype=torch.float16,
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
model = LlamaForCausalLM.from_pretrained(
|
78 |
+
BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
|
79 |
+
)
|
80 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
81 |
+
model,
|
82 |
+
LORA_WEIGHTS,
|
83 |
+
device_map={"": device},
|
84 |
+
)
|
85 |
+
|
86 |
+
|
87 |
+
def generate_prompt(instruction, input=None):
|
88 |
+
if input:
|
89 |
+
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
90 |
+
|
91 |
+
### Instruction:
|
92 |
+
{instruction}
|
93 |
+
|
94 |
+
### Input:
|
95 |
+
{input}
|
96 |
+
|
97 |
+
### Response:"""
|
98 |
+
else:
|
99 |
+
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
100 |
+
|
101 |
+
### Instruction:
|
102 |
+
{instruction}
|
103 |
+
|
104 |
+
### Response:"""
|
105 |
+
|
106 |
+
|
107 |
+
if not LOAD_8BIT:
|
108 |
+
model.half() # seems to fix bugs for some users.
|
109 |
+
|
110 |
+
model.eval()
|
111 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
112 |
+
model = torch.compile(model)
|
113 |
+
|
114 |
+
|
115 |
+
def evaluate(
|
116 |
+
input,
|
117 |
+
temperature=0.1,
|
118 |
+
top_p=0.75,
|
119 |
+
top_k=40,
|
120 |
+
num_beams=4,
|
121 |
+
max_new_tokens=128,
|
122 |
+
min_new_tokens=1,
|
123 |
+
repetition_penalty=2.0,
|
124 |
+
**kwargs,
|
125 |
+
):
|
126 |
+
prompt = generate_prompt(input)
|
127 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
128 |
+
input_ids = inputs["input_ids"].to(device)
|
129 |
+
generation_config = GenerationConfig(
|
130 |
+
temperature=temperature,
|
131 |
+
top_p=top_p,
|
132 |
+
top_k=top_k,
|
133 |
+
num_beams=num_beams,
|
134 |
+
bos_token_id=1,
|
135 |
+
eos_token_id=2,
|
136 |
+
pad_token_id=0,
|
137 |
+
max_new_tokens=max_new_tokens, # max_length=max_new_tokens+input_sequence
|
138 |
+
min_new_tokens=min_new_tokens, # min_length=min_new_tokens+input_sequence
|
139 |
+
**kwargs,
|
140 |
+
)
|
141 |
+
with torch.no_grad():
|
142 |
+
if args.use_typewriter:
|
143 |
+
for generation_output in model.stream_generate(
|
144 |
+
input_ids=input_ids,
|
145 |
+
generation_config=generation_config,
|
146 |
+
return_dict_in_generate=True,
|
147 |
+
output_scores=False,
|
148 |
+
repetition_penalty=float(repetition_penalty),
|
149 |
+
):
|
150 |
+
outputs = tokenizer.batch_decode(generation_output)
|
151 |
+
show_text = "\n--------------------------------------------\n".join(
|
152 |
+
[output.split("### Response:")[1].strip().replace('�','')+" ▌" for output in outputs]
|
153 |
+
)
|
154 |
+
# if show_text== '':
|
155 |
+
# yield last_show_text
|
156 |
+
# else:
|
157 |
+
yield show_text
|
158 |
+
yield outputs[0].split("### Response:")[1].strip().replace('�','')
|
159 |
+
else:
|
160 |
+
generation_output = model.generate(
|
161 |
+
input_ids=input_ids,
|
162 |
+
generation_config=generation_config,
|
163 |
+
return_dict_in_generate=True,
|
164 |
+
output_scores=False,
|
165 |
+
repetition_penalty=1.3,
|
166 |
+
)
|
167 |
+
output = generation_output.sequences[0]
|
168 |
+
output = tokenizer.decode(output).split("### Response:")[1].strip()
|
169 |
+
print(output)
|
170 |
+
yield output
|
171 |
+
|
172 |
+
|
173 |
+
gr.Interface(
|
174 |
+
fn=evaluate,
|
175 |
+
inputs=[
|
176 |
+
gr.components.Textbox(
|
177 |
+
lines=2, label="Input", placeholder="Tell me about alpacas."
|
178 |
+
),
|
179 |
+
gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
|
180 |
+
gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
|
181 |
+
gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
|
182 |
+
gr.components.Slider(minimum=1, maximum=10, step=1, value=4, label="Beams Number"),
|
183 |
+
gr.components.Slider(
|
184 |
+
minimum=1, maximum=2000, step=1, value=256, label="Max New Tokens"
|
185 |
+
),
|
186 |
+
gr.components.Slider(
|
187 |
+
minimum=1, maximum=300, step=1, value=1, label="Min New Tokens"
|
188 |
+
),
|
189 |
+
gr.components.Slider(
|
190 |
+
minimum=0.1, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"
|
191 |
+
),
|
192 |
+
],
|
193 |
+
outputs=[
|
194 |
+
gr.inputs.Textbox(
|
195 |
+
lines=25,
|
196 |
+
label="Output",
|
197 |
+
)
|
198 |
+
],
|
199 |
+
title="Chinese-Vicuna 中文小羊驼",
|
200 |
+
description="中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来,模型基于开源的llama7B,主要贡献是对应的lora模型。由于代码训练资源要求较小,希望为llama中文lora社区做一份贡献。",
|
201 |
+
).queue().launch(share=True)
|
generate_4bit.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
|
4 |
+
import transformers
|
5 |
+
import gradio as gr
|
6 |
+
import argparse
|
7 |
+
import warnings
|
8 |
+
import os
|
9 |
+
from utils import StreamPeftGenerationMixin,StreamLlamaForCausalLM
|
10 |
+
assert (
|
11 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
12 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
13 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig, BitsAndBytesConfig
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--model_path", type=str, default="/model/13B_hf")
|
17 |
+
parser.add_argument("--lora_path", type=str, default="checkpoint-3000")
|
18 |
+
parser.add_argument("--use_typewriter", type=int, default=1)
|
19 |
+
parser.add_argument("--use_local", type=int, default=1)
|
20 |
+
args = parser.parse_args()
|
21 |
+
print(args)
|
22 |
+
tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
|
23 |
+
|
24 |
+
LOAD_8BIT = True
|
25 |
+
BASE_MODEL = args.model_path
|
26 |
+
LORA_WEIGHTS = args.lora_path
|
27 |
+
|
28 |
+
|
29 |
+
# fix the path for local checkpoint
|
30 |
+
lora_bin_path = os.path.join(args.lora_path, "adapter_model.bin")
|
31 |
+
print(lora_bin_path)
|
32 |
+
if not os.path.exists(lora_bin_path) and args.use_local:
|
33 |
+
pytorch_bin_path = os.path.join(args.lora_path, "pytorch_model.bin")
|
34 |
+
print(pytorch_bin_path)
|
35 |
+
if os.path.exists(pytorch_bin_path):
|
36 |
+
os.rename(pytorch_bin_path, lora_bin_path)
|
37 |
+
warnings.warn(
|
38 |
+
"The file name of the lora checkpoint'pytorch_model.bin' is replaced with 'adapter_model.bin'"
|
39 |
+
)
|
40 |
+
else:
|
41 |
+
assert ('Checkpoint is not Found!')
|
42 |
+
|
43 |
+
if torch.cuda.is_available():
|
44 |
+
device = "cuda"
|
45 |
+
else:
|
46 |
+
device = "cpu"
|
47 |
+
|
48 |
+
try:
|
49 |
+
if torch.backends.mps.is_available():
|
50 |
+
device = "mps"
|
51 |
+
except:
|
52 |
+
pass
|
53 |
+
|
54 |
+
bnb_config = BitsAndBytesConfig(
|
55 |
+
load_in_4bit=True,
|
56 |
+
bnb_4bit_use_double_quant=True,
|
57 |
+
bnb_4bit_quant_type="nf4",
|
58 |
+
bnb_4bit_compute_dtype=torch.float16
|
59 |
+
)
|
60 |
+
|
61 |
+
if device == "cuda":
|
62 |
+
model = LlamaForCausalLM.from_pretrained(
|
63 |
+
BASE_MODEL,
|
64 |
+
quantization_config=bnb_config,
|
65 |
+
torch_dtype=torch.float16,
|
66 |
+
device_map="auto", #{"": 0},
|
67 |
+
)
|
68 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
69 |
+
model, LORA_WEIGHTS, torch_dtype=torch.float16, device_map="auto", #{"": 0}
|
70 |
+
)
|
71 |
+
elif device == "mps":
|
72 |
+
model = LlamaForCausalLM.from_pretrained(
|
73 |
+
BASE_MODEL,
|
74 |
+
device_map={"": device},
|
75 |
+
torch_dtype=torch.float16,
|
76 |
+
)
|
77 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
78 |
+
model,
|
79 |
+
LORA_WEIGHTS,
|
80 |
+
device_map={"": device},
|
81 |
+
torch_dtype=torch.float16,
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
model = LlamaForCausalLM.from_pretrained(
|
85 |
+
BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
|
86 |
+
)
|
87 |
+
model = StreamPeftGenerationMixin.from_pretrained(
|
88 |
+
model,
|
89 |
+
LORA_WEIGHTS,
|
90 |
+
device_map={"": device},
|
91 |
+
)
|
92 |
+
|
93 |
+
model.config.bos_token_id = tokenizer.bos_token_id = 1
|
94 |
+
model.config.eos_token_id = tokenizer.eos_token_id = 2
|
95 |
+
|
96 |
+
def generate_prompt(instruction, input=None):
|
97 |
+
if input:
|
98 |
+
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
99 |
+
|
100 |
+
### Instruction:
|
101 |
+
{instruction}
|
102 |
+
|
103 |
+
### Input:
|
104 |
+
{input}
|
105 |
+
|
106 |
+
### Response:"""
|
107 |
+
else:
|
108 |
+
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
109 |
+
|
110 |
+
### Instruction:
|
111 |
+
{instruction}
|
112 |
+
|
113 |
+
### Response:"""
|
114 |
+
|
115 |
+
|
116 |
+
if not LOAD_8BIT:
|
117 |
+
model.half() # seems to fix bugs for some users.
|
118 |
+
|
119 |
+
model.eval()
|
120 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
121 |
+
model = torch.compile(model)
|
122 |
+
|
123 |
+
|
124 |
+
def evaluate(
|
125 |
+
input,
|
126 |
+
temperature=0.1,
|
127 |
+
top_p=0.75,
|
128 |
+
top_k=40,
|
129 |
+
num_beams=4,
|
130 |
+
max_new_tokens=128,
|
131 |
+
min_new_tokens=1,
|
132 |
+
repetition_penalty=2.0,
|
133 |
+
**kwargs,
|
134 |
+
):
|
135 |
+
prompt = generate_prompt(input)
|
136 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
137 |
+
input_ids = inputs["input_ids"].to(device)
|
138 |
+
generation_config = GenerationConfig(
|
139 |
+
temperature=temperature,
|
140 |
+
top_p=top_p,
|
141 |
+
top_k=top_k,
|
142 |
+
num_beams=num_beams,
|
143 |
+
bos_token_id=1,
|
144 |
+
eos_token_id=2,
|
145 |
+
pad_token_id=0,
|
146 |
+
max_new_tokens=max_new_tokens, # max_length=max_new_tokens+input_sequence
|
147 |
+
min_new_tokens=min_new_tokens, # min_length=min_new_tokens+input_sequence
|
148 |
+
**kwargs,
|
149 |
+
)
|
150 |
+
with torch.no_grad():
|
151 |
+
if args.use_typewriter:
|
152 |
+
for generation_output in model.stream_generate(
|
153 |
+
input_ids=input_ids,
|
154 |
+
generation_config=generation_config,
|
155 |
+
return_dict_in_generate=True,
|
156 |
+
output_scores=False,
|
157 |
+
repetition_penalty=float(repetition_penalty),
|
158 |
+
):
|
159 |
+
outputs = tokenizer.batch_decode(generation_output)
|
160 |
+
show_text = "\n--------------------------------------------\n".join(
|
161 |
+
[output.split("### Response:")[1].strip().replace('�','')+" ▌" for output in outputs]
|
162 |
+
)
|
163 |
+
# if show_text== '':
|
164 |
+
# yield last_show_text
|
165 |
+
# else:
|
166 |
+
yield show_text
|
167 |
+
yield outputs[0].split("### Response:")[1].strip().replace('�','')
|
168 |
+
else:
|
169 |
+
generation_output = model.generate(
|
170 |
+
input_ids=input_ids,
|
171 |
+
generation_config=generation_config,
|
172 |
+
return_dict_in_generate=True,
|
173 |
+
output_scores=False,
|
174 |
+
repetition_penalty=1.3,
|
175 |
+
)
|
176 |
+
output = generation_output.sequences[0]
|
177 |
+
output = tokenizer.decode(output).split("### Response:")[1].strip()
|
178 |
+
print(output)
|
179 |
+
yield output
|
180 |
+
|
181 |
+
|
182 |
+
gr.Interface(
|
183 |
+
fn=evaluate,
|
184 |
+
inputs=[
|
185 |
+
gr.components.Textbox(
|
186 |
+
lines=2, label="Input", placeholder="Tell me about alpacas."
|
187 |
+
),
|
188 |
+
gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
|
189 |
+
gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
|
190 |
+
gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
|
191 |
+
gr.components.Slider(minimum=1, maximum=10, step=1, value=4, label="Beams Number"),
|
192 |
+
gr.components.Slider(
|
193 |
+
minimum=1, maximum=2000, step=1, value=256, label="Max New Tokens"
|
194 |
+
),
|
195 |
+
gr.components.Slider(
|
196 |
+
minimum=1, maximum=300, step=1, value=1, label="Min New Tokens"
|
197 |
+
),
|
198 |
+
gr.components.Slider(
|
199 |
+
minimum=0.1, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"
|
200 |
+
),
|
201 |
+
],
|
202 |
+
outputs=[
|
203 |
+
gr.inputs.Textbox(
|
204 |
+
lines=25,
|
205 |
+
label="Output",
|
206 |
+
)
|
207 |
+
],
|
208 |
+
title="Chinese-Vicuna 中文小羊驼",
|
209 |
+
description="中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来,模型基于开源的llama7B,主要贡献是对应的lora模型。由于代码训练资源要求较小,希望为llama中文lora社区做一份贡献。",
|
210 |
+
).queue().launch(share=True)
|
interaction.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from peft import PeftModel
|
4 |
+
import transformers
|
5 |
+
import gradio as gr
|
6 |
+
import argparse
|
7 |
+
import warnings
|
8 |
+
import os
|
9 |
+
|
10 |
+
|
11 |
+
assert (
|
12 |
+
"LlamaTokenizer" in transformers._import_structure["models.llama"]
|
13 |
+
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
|
14 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
|
15 |
+
|
16 |
+
parser = argparse.ArgumentParser()
|
17 |
+
parser.add_argument("--model_path", type=str, default="decapoda-research/llama-7b-hf")
|
18 |
+
parser.add_argument("--lora_path", type=str, default="./lora-Vicuna/checkpoint-final")
|
19 |
+
parser.add_argument("--use_local", type=int, default=1)
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
tokenizer = LlamaTokenizer.from_pretrained(args.model_path)
|
23 |
+
|
24 |
+
LOAD_8BIT = True
|
25 |
+
BASE_MODEL = args.model_path
|
26 |
+
LORA_WEIGHTS = args.lora_path
|
27 |
+
|
28 |
+
# fix the path for local checkpoint
|
29 |
+
lora_bin_path = os.path.join(args.lora_path, "adapter_model.bin")
|
30 |
+
print(lora_bin_path)
|
31 |
+
if not os.path.exists(lora_bin_path) and args.use_local:
|
32 |
+
pytorch_bin_path = os.path.join(args.lora_path, "pytorch_model.bin")
|
33 |
+
print(pytorch_bin_path)
|
34 |
+
if os.path.exists(pytorch_bin_path):
|
35 |
+
os.rename(pytorch_bin_path, lora_bin_path)
|
36 |
+
warnings.warn("The file name of the lora checkpoint'pytorch_model.bin' is replaced with 'adapter_model.bin'")
|
37 |
+
else:
|
38 |
+
assert ('Checkpoint is not Found!')
|
39 |
+
if torch.cuda.is_available():
|
40 |
+
device = "cuda"
|
41 |
+
else:
|
42 |
+
device = "cpu"
|
43 |
+
|
44 |
+
try:
|
45 |
+
if torch.backends.mps.is_available():
|
46 |
+
device = "mps"
|
47 |
+
except:
|
48 |
+
pass
|
49 |
+
|
50 |
+
if device == "cuda":
|
51 |
+
model = LlamaForCausalLM.from_pretrained(
|
52 |
+
BASE_MODEL,
|
53 |
+
load_in_8bit=LOAD_8BIT,
|
54 |
+
torch_dtype=torch.float16,
|
55 |
+
device_map="auto", #device_map={"": 0},
|
56 |
+
)
|
57 |
+
model = PeftModel.from_pretrained(
|
58 |
+
model,
|
59 |
+
LORA_WEIGHTS,
|
60 |
+
torch_dtype=torch.float16,
|
61 |
+
device_map="auto", #device_map={"": 0},
|
62 |
+
)
|
63 |
+
elif device == "mps":
|
64 |
+
model = LlamaForCausalLM.from_pretrained(
|
65 |
+
BASE_MODEL,
|
66 |
+
device_map={"": device},
|
67 |
+
torch_dtype=torch.float16,
|
68 |
+
)
|
69 |
+
model = PeftModel.from_pretrained(
|
70 |
+
model,
|
71 |
+
LORA_WEIGHTS,
|
72 |
+
device_map={"": device},
|
73 |
+
torch_dtype=torch.float16,
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
model = LlamaForCausalLM.from_pretrained(
|
77 |
+
BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
|
78 |
+
)
|
79 |
+
model = PeftModel.from_pretrained(
|
80 |
+
model,
|
81 |
+
LORA_WEIGHTS,
|
82 |
+
device_map={"": device},
|
83 |
+
)
|
84 |
+
|
85 |
+
def generate_prompt(instruction, input=None):
|
86 |
+
if input:
|
87 |
+
return f"""The following is a conversation between an AI assistant called Assistant and a human user called User.
|
88 |
+
|
89 |
+
### Instruction:
|
90 |
+
{instruction}
|
91 |
+
|
92 |
+
### Input:
|
93 |
+
{input}
|
94 |
+
|
95 |
+
### Response:"""
|
96 |
+
else:
|
97 |
+
return f"""The following is a conversation between an AI assistant called Assistant and a human user called User.
|
98 |
+
|
99 |
+
### Instruction:
|
100 |
+
{instruction}
|
101 |
+
|
102 |
+
### Response:"""
|
103 |
+
|
104 |
+
if not LOAD_8BIT:
|
105 |
+
model.half() # seems to fix bugs for some users.
|
106 |
+
|
107 |
+
model.eval()
|
108 |
+
if torch.__version__ >= "2" and sys.platform != "win32":
|
109 |
+
model = torch.compile(model)
|
110 |
+
|
111 |
+
def interaction(
|
112 |
+
input,
|
113 |
+
history,
|
114 |
+
temperature=0.1,
|
115 |
+
top_p=0.75,
|
116 |
+
top_k=40,
|
117 |
+
num_beams=4,
|
118 |
+
max_new_tokens=128,
|
119 |
+
repetition_penalty=1.0,
|
120 |
+
max_memory=256,
|
121 |
+
**kwargs,
|
122 |
+
):
|
123 |
+
now_input = input
|
124 |
+
history = history or []
|
125 |
+
if len(history) != 0:
|
126 |
+
input = "\n".join(["User:" + i[0]+"\n"+"Assistant:" + i[1] for i in history]) + "\n" + "User:" + input
|
127 |
+
if len(input) > max_memory:
|
128 |
+
input = input[-max_memory:]
|
129 |
+
print(input)
|
130 |
+
print(len(input))
|
131 |
+
prompt = generate_prompt(input)
|
132 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
133 |
+
input_ids = inputs["input_ids"].to(device)
|
134 |
+
generation_config = GenerationConfig(
|
135 |
+
temperature=temperature,
|
136 |
+
top_p=top_p,
|
137 |
+
top_k=top_k,
|
138 |
+
num_beams=num_beams,
|
139 |
+
**kwargs,
|
140 |
+
)
|
141 |
+
with torch.no_grad():
|
142 |
+
generation_output = model.generate(
|
143 |
+
input_ids=input_ids,
|
144 |
+
generation_config=generation_config,
|
145 |
+
return_dict_in_generate=True,
|
146 |
+
output_scores=True,
|
147 |
+
max_new_tokens=max_new_tokens,
|
148 |
+
repetition_penalty=float(repetition_penalty),
|
149 |
+
)
|
150 |
+
s = generation_output.sequences[0]
|
151 |
+
output = tokenizer.decode(s)
|
152 |
+
output = output.split("### Response:")[1].strip()
|
153 |
+
output = output.replace("Belle", "Vicuna")
|
154 |
+
if 'User:' in output:
|
155 |
+
output = output.split("User:")[0]
|
156 |
+
history.append((now_input, output))
|
157 |
+
print(history)
|
158 |
+
return history, history
|
159 |
+
|
160 |
+
chatbot = gr.Chatbot().style(color_map=("green", "pink"))
|
161 |
+
demo = gr.Interface(
|
162 |
+
fn=interaction,
|
163 |
+
inputs=[
|
164 |
+
gr.components.Textbox(
|
165 |
+
lines=2, label="Input", placeholder="Tell me about alpacas."
|
166 |
+
),
|
167 |
+
"state",
|
168 |
+
gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Temperature"),
|
169 |
+
gr.components.Slider(minimum=0, maximum=1, value=0.9, label="Top p"),
|
170 |
+
gr.components.Slider(minimum=0, maximum=100, step=1, value=60, label="Top k"),
|
171 |
+
gr.components.Slider(minimum=1, maximum=5, step=1, value=2, label="Beams"),
|
172 |
+
gr.components.Slider(
|
173 |
+
minimum=1, maximum=2000, step=1, value=128, label="Max new tokens"
|
174 |
+
),
|
175 |
+
gr.components.Slider(
|
176 |
+
minimum=0.1, maximum=10.0, step=0.1, value=2.0, label="Repetition Penalty"
|
177 |
+
),
|
178 |
+
gr.components.Slider(
|
179 |
+
minimum=0, maximum=2000, step=1, value=256, label="max memory"
|
180 |
+
),
|
181 |
+
],
|
182 |
+
outputs=[chatbot, "state"],
|
183 |
+
allow_flagging="auto",
|
184 |
+
title="Chinese-Vicuna 中文小羊驼",
|
185 |
+
description="中文小羊驼由各种高质量的开源instruction数据集,结合Alpaca-lora的代码训练而来,模型基于开源的llama7B,主要贡献是对应的lora模型。由于代码训练资源要求较小,希望为llama中文lora社区做一份贡献。",
|
186 |
+
)
|
187 |
+
demo.queue().launch(share=True, inbrowser=True)
|
prompt.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import transformers
|
3 |
+
from utils import printf
|
4 |
+
import copy
|
5 |
+
|
6 |
+
class prompt:
|
7 |
+
def __init__(self, tokenizer, max_len, add_eos=True):
|
8 |
+
self.tokenizer = tokenizer
|
9 |
+
self.max_len = max_len
|
10 |
+
self.add_eos=add_eos
|
11 |
+
|
12 |
+
class instruct_prompt(prompt):
|
13 |
+
prompt = (
|
14 |
+
"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
15 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
16 |
+
)
|
17 |
+
prompt_input = (
|
18 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
19 |
+
"### Instruction:{instruction}\n\n### Input:{input}\n\n### Response:"
|
20 |
+
)
|
21 |
+
prompt_history = "User:{input}\n\nAssistant:{output}\n\n"
|
22 |
+
prompt_post = "User:{input}\n\nAssistant:"
|
23 |
+
|
24 |
+
def preprocess_gen(self, data_point):
|
25 |
+
if 'history' not in data_point:
|
26 |
+
# single instruction format {'instruction':..,'input':..}
|
27 |
+
if 'input' in data_point:
|
28 |
+
user_prompt = self.prompt_input.format_map(data_point)
|
29 |
+
else:
|
30 |
+
user_prompt = self.prompt.format_map(data_point)
|
31 |
+
else:
|
32 |
+
# multi turn format {'history':[..], 'input':[..]}
|
33 |
+
user_prompt = "\n".join(["User:" + i['input']+"\n"+"Assistant:" + i['output'] for i in data_point['history']]) + "\nUser:" + data_point['input'] + "\nAssistant:"
|
34 |
+
user_prompt = user_prompt[-self.max_len:]
|
35 |
+
user_prompt=self.prompt.format_map({'instruction':user_prompt})
|
36 |
+
input_ids = self.tokenizer(user_prompt)["input_ids"]
|
37 |
+
return input_ids
|
38 |
+
|
39 |
+
def preprocess_train(self, data_point):
|
40 |
+
# single instruction format {'instruction':..,'input':..,'output':..}
|
41 |
+
if 'instruction' in data_point:
|
42 |
+
if 'input' in data_point:
|
43 |
+
user_prompt = self.prompt_input.format_map(data_point)
|
44 |
+
else:
|
45 |
+
user_prompt = self.prompt.format_map(data_point)
|
46 |
+
output = data_point["output"]
|
47 |
+
# multi turn format {'input':[..], 'output':[..]}
|
48 |
+
else:
|
49 |
+
user_prompt = ''
|
50 |
+
lens = len(data_point['input'])
|
51 |
+
for i in range(lens-1):
|
52 |
+
user_prompt += self.prompt_history.format_map({'input':data_point['input'][i],'output':data_point['output'][i]})
|
53 |
+
user_prompt += self.prompt_post.format_map({'input':data_point['input'][-1]})
|
54 |
+
user_prompt = self.prompt.format_map({'instruction': user_prompt})
|
55 |
+
output = data_point['output'][-1]
|
56 |
+
|
57 |
+
len_user_prompt_tokens = (len(self.tokenizer(
|
58 |
+
user_prompt,
|
59 |
+
truncation=True,
|
60 |
+
max_length=self.max_len + 1,
|
61 |
+
)["input_ids"])- 1) # no eos token
|
62 |
+
full_tokens = self.tokenizer(
|
63 |
+
user_prompt + output,
|
64 |
+
truncation=True,
|
65 |
+
max_length=self.max_len + 1,
|
66 |
+
padding="max_length",
|
67 |
+
)["input_ids"][:-1]
|
68 |
+
return {
|
69 |
+
"input_ids": full_tokens,
|
70 |
+
"labels": [-100] * len_user_prompt_tokens
|
71 |
+
+ full_tokens[len_user_prompt_tokens:],
|
72 |
+
"attention_mask": [1] * (len(full_tokens)),
|
73 |
+
}
|
74 |
+
|
75 |
+
def data_collator(self,):
|
76 |
+
return transformers.DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
|
77 |
+
|
78 |
+
def postprocess(self, text, render=True):
|
79 |
+
#import pdb;pdb.set_trace()
|
80 |
+
printf(text)
|
81 |
+
output = text.split("### Response:")[1].strip()
|
82 |
+
output = output.replace("Belle", "Vicuna")
|
83 |
+
printf(output)
|
84 |
+
if '###' in output:
|
85 |
+
output = output.split("###")[0]
|
86 |
+
if 'User' in output:
|
87 |
+
output = output.split("User")[0]
|
88 |
+
output = output.replace('�','').replace('</s>', '')
|
89 |
+
if render:
|
90 |
+
# fix gradio chatbot markdown code render bug
|
91 |
+
lines = output.split("\n")
|
92 |
+
for i, line in enumerate(lines):
|
93 |
+
if "```" in line:
|
94 |
+
if line != "```":
|
95 |
+
lines[i] = f'<pre><code class="language-{lines[i][3:]}">'
|
96 |
+
else:
|
97 |
+
lines[i] = '</code></pre>'
|
98 |
+
else:
|
99 |
+
if i > 0:
|
100 |
+
lines[i] = "<br/>" + line.replace("<", "<").replace(">", ">").replace("__", '\_\_')
|
101 |
+
output = "".join(lines)
|
102 |
+
# output = output.replace('<br/><pre>','\n<pre>') work for html; but not for gradio
|
103 |
+
return output
|
104 |
+
|
105 |
+
class chat_prompt(prompt):
|
106 |
+
prompt_pre = (
|
107 |
+
"The following is a conversation between an AI assistant called Assistant and a human user called User. "
|
108 |
+
"The assistant is intelligent, knowledgeable and polite to answer questions of user.\n\n"
|
109 |
+
)
|
110 |
+
prompt_history = "User:{input}\n\nAssistant:{output}\n\n"
|
111 |
+
prompt_post = "User:{input}\n\nAssistant:"
|
112 |
+
|
113 |
+
def preprocess_gen(self, data_point):
|
114 |
+
user_prompt = self.prompt_pre
|
115 |
+
len_avail = self.max_len - len(self.tokenizer(user_prompt, add_special_tokens=False)['input_ids'])
|
116 |
+
input_prompt = self.prompt_post.format_map({'input':data_point['input']})
|
117 |
+
len_avail -= len(self.tokenizer(input_prompt, add_special_tokens=False)['input_ids'])
|
118 |
+
lens = len(data_point['history'])
|
119 |
+
tokenized_lens = []
|
120 |
+
for i in range(lens):
|
121 |
+
tmp_prompt = self.prompt_history.format_map(data_point['history'][i])
|
122 |
+
tokenized_lens.append(len(self.tokenizer(tmp_prompt,add_special_tokens=False)["input_ids"]))
|
123 |
+
|
124 |
+
# 启发式:/2 优先除前面的
|
125 |
+
i = 0
|
126 |
+
while sum(tokenized_lens) > len_avail and i < lens:
|
127 |
+
history = data_point['history'][i]
|
128 |
+
tmp_len1 = len(history['input'])
|
129 |
+
tmp_len2 = len(history['output'])
|
130 |
+
if tmp_len2 > tmp_len1:
|
131 |
+
history['output'] = history['output'][:tmp_len2//2]
|
132 |
+
else:
|
133 |
+
history['input'] = history['input'][:tmp_len1//2]
|
134 |
+
prompt = self.prompt_history.format_map(history)
|
135 |
+
single_len =(len(self.tokenizer(prompt,add_special_tokens=False)["input_ids"]))
|
136 |
+
tokenized_lens[i] = single_len
|
137 |
+
i += 1
|
138 |
+
total_len = sum(tokenized_lens)
|
139 |
+
# 还不够的话 直接截断
|
140 |
+
while total_len > len_avail and i < lens - 1 :
|
141 |
+
total_len -= tokenized_lens[i]
|
142 |
+
data_point['history'] = data_point['history'][1:]
|
143 |
+
i += 1
|
144 |
+
# 最终合并
|
145 |
+
for i in range(lens):
|
146 |
+
user_prompt += self.prompt_history.format_map(data_point['history'][i])
|
147 |
+
user_prompt += input_prompt
|
148 |
+
printf({'real_input:':user_prompt})
|
149 |
+
inputs = self.tokenizer(user_prompt)["input_ids"]
|
150 |
+
return inputs
|
151 |
+
|
152 |
+
def preprocess_train(self, data_point):
|
153 |
+
user_prompt = self.prompt_pre
|
154 |
+
lens = len(data_point['input'])
|
155 |
+
for i in range(lens-1):
|
156 |
+
user_prompt += self.prompt_history.format_map({'input':data_point['input'][i].strip(),'output':data_point['output'][i].strip()})
|
157 |
+
user_prompt += self.prompt_post.format_map({'input':data_point['input'][-1].strip()})
|
158 |
+
|
159 |
+
len_user_prompt_tokens = len(self.tokenizer(
|
160 |
+
user_prompt,
|
161 |
+
truncation=True,
|
162 |
+
max_length=self.max_len,
|
163 |
+
)["input_ids"]) - 1 # remove extra eos
|
164 |
+
if self.add_eos:
|
165 |
+
full_tokens = self.tokenizer(
|
166 |
+
user_prompt + data_point["output"][-1].strip(),
|
167 |
+
truncation=True,
|
168 |
+
padding=False,
|
169 |
+
max_length=self.max_len,
|
170 |
+
)["input_ids"] # need eos
|
171 |
+
else:
|
172 |
+
full_tokens = self.tokenizer(
|
173 |
+
user_prompt + data_point["output"][-1].strip(),
|
174 |
+
truncation=True,
|
175 |
+
padding=False,
|
176 |
+
max_length=self.max_len+1,
|
177 |
+
)["input_ids"][:-1] # delete eos
|
178 |
+
return {
|
179 |
+
"input_ids": full_tokens,
|
180 |
+
"labels": [-100] * len_user_prompt_tokens + full_tokens[len_user_prompt_tokens:],
|
181 |
+
"attention_mask": [1] * (len(full_tokens)),
|
182 |
+
}
|
183 |
+
|
184 |
+
def data_collator(self,):
|
185 |
+
return transformers.DataCollatorForSeq2Seq(self.tokenizer)
|
186 |
+
|
187 |
+
def postprocess(self, text, render=False):
|
188 |
+
output = text.split("Assistant:")[-1].strip()
|
189 |
+
if 'User:' in output:
|
190 |
+
output = output.split("User:")[0]
|
191 |
+
output = output.replace('�','')
|
192 |
+
if render:
|
193 |
+
# fix gradio chatbot markdown code render bug
|
194 |
+
lines = output.split("\n")
|
195 |
+
for i, line in enumerate(lines):
|
196 |
+
if "```" in line:
|
197 |
+
if line != "```":
|
198 |
+
lines[i] = f'<pre><code class="language-{lines[i][3:]}">'
|
199 |
+
else:
|
200 |
+
lines[i] = '</code></pre>'
|
201 |
+
else:
|
202 |
+
if i > 0:
|
203 |
+
lines[i] = "<br/>" + line.replace("<", "<").replace(">", ">").replace("__", '\_\_')
|
204 |
+
output = "".join(lines)
|
205 |
+
# output = output.replace('<br/><pre>','\n<pre>') work for html; but not for gradio
|
206 |
+
return output
|
207 |
+
|
208 |
+
def get_data_collator():
|
209 |
+
return transformers.DataCollatorForLanguageModeling
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.15.0
|
2 |
+
appdirs==1.4.4
|
3 |
+
bitsandbytes==0.37.0
|
4 |
+
datasets==2.8.0
|
5 |
+
deepspeed==0.8.3
|
6 |
+
evaluate==0.4.0
|
7 |
+
fairscale==0.4.13
|
8 |
+
torch==1.13.1
|
9 |
+
torchvision==0.14.1
|
10 |
+
gradio==3.20.0
|
11 |
+
huggingface-hub==0.13.3
|
12 |
+
loralib==0.1.1
|
13 |
+
nvitop==1.0.0
|
14 |
+
peft @ git+https://github.com/huggingface/peft.git@13e53fc7ee5d89d59b16523051006dddf0fb7a49
|
15 |
+
sentencepiece==0.1.96
|
16 |
+
tensorboard==2.12.0
|
17 |
+
texttable==1.6.7
|
18 |
+
tokenizers==0.13.2
|
19 |
+
tqdm==4.65.0
|
20 |
+
transformers @ git+https://github.com/huggingface/transformers.git@0dcb46e7a4a9e587ba84ff35778ab4233a184c11
|
21 |
+
trlx @ git+https://github.com/CarperAI/trlx.git@b91da7b03d8e9fa0c0d6dce10a8f2611aca3013f
|
22 |
+
wandb==0.13.10
|
23 |
+
triton==2.0.0
|
requirements_4bit.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate @ git+https://github.com/huggingface/accelerate.git@f9b2e6769b1279b11335d2b87801fa1ca81bb056
|
2 |
+
appdirs==1.4.4
|
3 |
+
bitsandbytes==0.39.0
|
4 |
+
datasets==2.8.0
|
5 |
+
deepspeed==0.8.3
|
6 |
+
evaluate==0.4.0
|
7 |
+
fairscale==0.4.13
|
8 |
+
torch==1.13.1
|
9 |
+
torchvision==0.14.1
|
10 |
+
gradio==3.20.0
|
11 |
+
huggingface-hub==0.13.3
|
12 |
+
loralib==0.1.1
|
13 |
+
nvitop==1.0.0
|
14 |
+
peft @ git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
|
15 |
+
sentencepiece==0.1.96
|
16 |
+
tensorboard==2.12.0
|
17 |
+
texttable==1.6.7
|
18 |
+
tokenizers==0.13.2
|
19 |
+
tqdm==4.65.0
|
20 |
+
transformers @ git+https://github.com/huggingface/transformers@e45e756d22206ca8fa9fb057c8c3d8fa79bf81c6
|
21 |
+
trlx @ git+https://github.com/CarperAI/trlx.git@b91da7b03d8e9fa0c0d6dce10a8f2611aca3013f
|
22 |
+
wandb==0.13.10
|
23 |
+
triton==2.0.0
|
test_tokenizer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import transformers
|
5 |
+
import argparse
|
6 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig
|
7 |
+
|
8 |
+
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument("--model_path", type=str, default="yahma/llama-7b-hf") #yahma/llama-7b-hf #decapoda-research/llama-7b-hf
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
14 |
+
args.model_path, add_eos_token=True
|
15 |
+
)
|
16 |
+
|
17 |
+
test_text = ["Hello, nice to meet you!", "你好很高兴能见到你!"]
|
18 |
+
|
19 |
+
for text in test_text:
|
20 |
+
input_ids = tokenizer.encode(text)
|
21 |
+
print(f"input_ids: {input_ids}")
|
22 |
+
decode_text = tokenizer.decode(input_ids)
|
23 |
+
print(f"decode_text: {decode_text}")
|
24 |
+
|
25 |
+
"""
|
26 |
+
Correct ==> yahma/llama-7b-hf + newest Transformers(>=4.28.1):
|
27 |
+
> !!! Beginning with 1 (bos), ending with 2 (eos) !!!
|
28 |
+
|
29 |
+
input_ids: [1, 15043, 29892, 7575, 304, 5870, 366, 29991, 2]
|
30 |
+
decode_text: <s> Hello, nice to meet you!</s>
|
31 |
+
input_ids: [1, 29871, 30919, 31076, 232, 193, 139, 30528, 31914, 30815, 235, 170, 132, 30780, 30919, 30584, 2]
|
32 |
+
decode_text: <s> 你好很高兴能见到你!</s>
|
33 |
+
|
34 |
+
Correct ==> decapoda-research/llama-7b-hf + Old Transformers like our version(transformers @ git+https://github.com/huggingface/transformers.git@0dcb46e7a4a9e587ba84ff35778ab4233a184c11)
|
35 |
+
input_ids: [1, 15043, 29892, 7575, 304, 5870, 366, 29991, 2]
|
36 |
+
decode_text: Hello, nice to meet you!
|
37 |
+
input_ids: [1, 29871, 30919, 31076, 232, 193, 139, 30528, 31914, 30815, 235, 170, 132, 30780, 30919, 30584, 2]
|
38 |
+
decode_text: 你好很高兴能见到你!
|
39 |
+
|
40 |
+
Correct ==> decapoda-research/llama-7b-hf + Old Transformers like our version(transformers @ git+https://github.com/huggingface/transformers.git@0dcb46e7a4a9e587ba84ff35778ab4233a184c11)
|
41 |
+
input_ids: [1, 15043, 29892, 7575, 304, 5870, 366, 29991, 2]
|
42 |
+
decode_text: Hello, nice to meet you!
|
43 |
+
input_ids: [1, 29871, 30919, 31076, 232, 193, 139, 30528, 31914, 30815, 235, 170, 132, 30780, 30919, 30584, 2]
|
44 |
+
decode_text: 你好很高兴能见到你!
|
45 |
+
|
46 |
+
|
47 |
+
老版本transformers的问题:代码默认加载tokenizer.model
|
48 |
+
新版本transformers的修改:新版本默认加载config
|
49 |
+
|
50 |
+
decapoda-research:config的bos=0,eos=1(×),tokenizer.model是正确的
|
51 |
+
yahma:config的bos=1,eos=2,tokenizer.model是正确的
|
52 |
+
"""
|
utils.py
ADDED
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import json
|
6 |
+
from typing import Optional, Tuple, Union, List, Callable
|
7 |
+
from transformers import LlamaForCausalLM
|
8 |
+
from transformers.generation.logits_process import LogitsProcessor
|
9 |
+
from transformers.generation.beam_search import BeamSearchScorer
|
10 |
+
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
11 |
+
from transformers.generation.utils import (
|
12 |
+
LogitsProcessorList,
|
13 |
+
StoppingCriteriaList,
|
14 |
+
GenerationConfig,
|
15 |
+
GenerationMixin,
|
16 |
+
)
|
17 |
+
import warnings
|
18 |
+
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
|
19 |
+
import peft
|
20 |
+
import torch.distributed as dist
|
21 |
+
from torch import nn
|
22 |
+
import copy
|
23 |
+
from accelerate.hooks import (
|
24 |
+
AlignDevicesHook,
|
25 |
+
add_hook_to_module,
|
26 |
+
remove_hook_from_submodules,
|
27 |
+
)
|
28 |
+
from accelerate.utils import get_balanced_memory
|
29 |
+
from huggingface_hub import hf_hub_download
|
30 |
+
from accelerate import dispatch_model, infer_auto_device_map
|
31 |
+
from peft.utils import PeftType, set_peft_model_state_dict
|
32 |
+
|
33 |
+
def printf(*args,**kargs):
|
34 |
+
if os.environ.get('DEBUG',False):
|
35 |
+
end = '\n'
|
36 |
+
if 'end' in kargs:
|
37 |
+
end = kargs['end']
|
38 |
+
print(*args, end=end, flush=True)
|
39 |
+
|
40 |
+
class ColorFormatter(logging.Formatter):
|
41 |
+
|
42 |
+
grey = "\x1b[38;20m"
|
43 |
+
blue = "\x1b[34;20m"
|
44 |
+
yellow = "\x1b[33;20m"
|
45 |
+
red = "\x1b[31;20m"
|
46 |
+
bold_red = "\x1b[31;1m"
|
47 |
+
reset = "\x1b[0m"
|
48 |
+
|
49 |
+
def __init__(self, fmt):
|
50 |
+
super().__init__(fmt)
|
51 |
+
self.FORMATS = {
|
52 |
+
logging.DEBUG: self.grey + fmt + self.reset,
|
53 |
+
logging.INFO: self.blue + fmt + self.reset,
|
54 |
+
logging.WARNING: self.yellow + fmt + self.reset,
|
55 |
+
logging.ERROR: self.red + fmt + self.reset,
|
56 |
+
logging.CRITICAL: self.bold_red + fmt + self.reset
|
57 |
+
}
|
58 |
+
|
59 |
+
def format(self, record):
|
60 |
+
log_fmt = self.FORMATS.get(record.levelno)
|
61 |
+
formatter = logging.Formatter(log_fmt)
|
62 |
+
return formatter.format(record)
|
63 |
+
|
64 |
+
def set_console_logger(name):
|
65 |
+
logger = logging.getLogger(name)
|
66 |
+
logger.setLevel(logging.DEBUG)
|
67 |
+
consoleHandler = logging.StreamHandler(sys.stdout)
|
68 |
+
consoleHandler.setLevel(logging.INFO)
|
69 |
+
consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
|
70 |
+
logger.addHandler(consoleHandler)
|
71 |
+
return logger
|
72 |
+
|
73 |
+
def set_file_logger(name, dir, use_console=False):
|
74 |
+
logger = logging.getLogger(name)
|
75 |
+
logger.setLevel(logging.DEBUG)
|
76 |
+
os.makedirs(dir, exist_ok=True)
|
77 |
+
|
78 |
+
if use_console:
|
79 |
+
logger.propagate = False # disable default handler
|
80 |
+
consoleHandler = logging.StreamHandler(sys.stdout)
|
81 |
+
consoleHandler.setLevel(logging.INFO)
|
82 |
+
consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
|
83 |
+
logger.addHandler(consoleHandler)
|
84 |
+
|
85 |
+
fileHandler = logging.FileHandler(os.path.join(dir,'session.log'), mode='a')
|
86 |
+
fileHandler.setLevel(logging.INFO)
|
87 |
+
fileHandler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s %(message)s"))
|
88 |
+
logger.addHandler(fileHandler)
|
89 |
+
return logger
|
90 |
+
|
91 |
+
def to_jsonl(data, path):
|
92 |
+
with open(path, 'a') as f:
|
93 |
+
for line in data:
|
94 |
+
f.write(json.dumps(line,ensure_ascii=False)+'\n')
|
95 |
+
|
96 |
+
def from_json(path):
|
97 |
+
return json.load(open(path))
|
98 |
+
|
99 |
+
def from_jsonl(path):
|
100 |
+
return [json.loads(line) for line in open(path, 'r') ]
|
101 |
+
|
102 |
+
def to_json(data, path):
|
103 |
+
json.dump(data, open(path, 'w'), ensure_ascii=False)
|
104 |
+
|
105 |
+
class StreamGenerationMixin(GenerationMixin):
|
106 |
+
# support for streamly generation
|
107 |
+
# TODO: group_beam_search
|
108 |
+
@torch.no_grad()
|
109 |
+
def stream_generate(
|
110 |
+
self,
|
111 |
+
input_ids: Optional[torch.Tensor] = None,
|
112 |
+
generation_config: Optional[GenerationConfig] = None,
|
113 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
114 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
115 |
+
prefix_allowed_tokens_fn: Optional[
|
116 |
+
Callable[[int, torch.Tensor], List[int]]
|
117 |
+
] = None,
|
118 |
+
**kwargs,
|
119 |
+
):
|
120 |
+
if is_deepspeed_zero3_enabled() and dist.world_size() > 1:
|
121 |
+
synced_gpus = True
|
122 |
+
else:
|
123 |
+
synced_gpus = False
|
124 |
+
|
125 |
+
if kwargs.get("attention_mask", None) is not None:
|
126 |
+
# concat prompt attention mask
|
127 |
+
prefix_attention_mask = torch.ones(
|
128 |
+
kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens
|
129 |
+
).to(kwargs["input_ids"].device)
|
130 |
+
kwargs["attention_mask"] = torch.cat(
|
131 |
+
(prefix_attention_mask, kwargs["attention_mask"]), dim=1
|
132 |
+
)
|
133 |
+
if kwargs.get("position_ids", None) is not None:
|
134 |
+
warnings.warn(
|
135 |
+
"Position ids are not supported for parameter efficient tuning. Ignoring position ids."
|
136 |
+
)
|
137 |
+
kwargs["position_ids"] = None
|
138 |
+
if kwargs.get("token_type_ids", None) is not None:
|
139 |
+
warnings.warn(
|
140 |
+
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids"
|
141 |
+
)
|
142 |
+
kwargs["token_type_ids"] = None
|
143 |
+
|
144 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
145 |
+
|
146 |
+
if generation_config is None:
|
147 |
+
generation_config = self.generation_config
|
148 |
+
generation_config = copy.deepcopy(generation_config)
|
149 |
+
model_kwargs = generation_config.update(**kwargs)
|
150 |
+
|
151 |
+
bos_token_id, eos_token_id, pad_token_id = (
|
152 |
+
generation_config.bos_token_id,
|
153 |
+
generation_config.eos_token_id,
|
154 |
+
generation_config.pad_token_id,
|
155 |
+
)
|
156 |
+
|
157 |
+
if isinstance(eos_token_id, int):
|
158 |
+
eos_token_id = [eos_token_id]
|
159 |
+
|
160 |
+
has_default_max_length = (
|
161 |
+
kwargs.get("max_length") is None
|
162 |
+
and generation_config.max_length is not None
|
163 |
+
)
|
164 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
165 |
+
warnings.warn(
|
166 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
167 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
168 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
169 |
+
UserWarning,
|
170 |
+
)
|
171 |
+
elif generation_config.max_new_tokens is not None:
|
172 |
+
generation_config.max_length = (
|
173 |
+
generation_config.max_new_tokens + input_ids_seq_length
|
174 |
+
)
|
175 |
+
if generation_config.min_new_tokens is not None:
|
176 |
+
generation_config.min_length = (
|
177 |
+
generation_config.min_new_tokens + input_ids_seq_length
|
178 |
+
)
|
179 |
+
|
180 |
+
if input_ids_seq_length >= generation_config.max_length:
|
181 |
+
input_ids_string = (
|
182 |
+
"decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
183 |
+
)
|
184 |
+
|
185 |
+
# 2. Set generation parameters if not already defined
|
186 |
+
logits_processor = (
|
187 |
+
logits_processor if logits_processor is not None else LogitsProcessorList()
|
188 |
+
)
|
189 |
+
stopping_criteria = (
|
190 |
+
stopping_criteria
|
191 |
+
if stopping_criteria is not None
|
192 |
+
else StoppingCriteriaList()
|
193 |
+
)
|
194 |
+
# 7. determine generation mode
|
195 |
+
is_constraint_gen_mode = (
|
196 |
+
generation_config.constraints is not None or generation_config.force_words_ids is not None
|
197 |
+
)
|
198 |
+
|
199 |
+
is_contrastive_search_gen_mode = (
|
200 |
+
generation_config.top_k is not None
|
201 |
+
and generation_config.top_k > 1
|
202 |
+
and generation_config.do_sample is False
|
203 |
+
and generation_config.penalty_alpha is not None
|
204 |
+
and generation_config.penalty_alpha > 0
|
205 |
+
)
|
206 |
+
|
207 |
+
is_greedy_gen_mode = (
|
208 |
+
(generation_config.num_beams == 1)
|
209 |
+
and (generation_config.num_beam_groups == 1)
|
210 |
+
and generation_config.do_sample is False
|
211 |
+
and not is_constraint_gen_mode
|
212 |
+
and not is_contrastive_search_gen_mode
|
213 |
+
)
|
214 |
+
# beam=1 and do_sample=True
|
215 |
+
is_sample_gen_mode = (
|
216 |
+
(generation_config.num_beams == 1)
|
217 |
+
and (generation_config.num_beam_groups == 1)
|
218 |
+
and generation_config.do_sample is True
|
219 |
+
and not is_constraint_gen_mode
|
220 |
+
and not is_contrastive_search_gen_mode
|
221 |
+
)
|
222 |
+
is_beam_gen_mode = (
|
223 |
+
(generation_config.num_beams > 1)
|
224 |
+
and (generation_config.num_beam_groups == 1)
|
225 |
+
and generation_config.do_sample is False
|
226 |
+
and not is_constraint_gen_mode
|
227 |
+
and not is_contrastive_search_gen_mode
|
228 |
+
)
|
229 |
+
is_beam_sample_gen_mode = (
|
230 |
+
(generation_config.num_beams > 1)
|
231 |
+
and (generation_config.num_beam_groups == 1)
|
232 |
+
and generation_config.do_sample is True
|
233 |
+
and not is_constraint_gen_mode
|
234 |
+
and not is_contrastive_search_gen_mode
|
235 |
+
)
|
236 |
+
is_group_beam_gen_mode = (
|
237 |
+
(generation_config.num_beams > 1)
|
238 |
+
and (generation_config.num_beam_groups > 1)
|
239 |
+
and not is_constraint_gen_mode
|
240 |
+
and not is_contrastive_search_gen_mode
|
241 |
+
)
|
242 |
+
# 8. prepare distribution pre_processing samplers
|
243 |
+
logits_processor = self._get_logits_processor(
|
244 |
+
generation_config=generation_config,
|
245 |
+
input_ids_seq_length=input_ids_seq_length,
|
246 |
+
encoder_input_ids=input_ids,
|
247 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
248 |
+
logits_processor=logits_processor,
|
249 |
+
)
|
250 |
+
# 9. prepare stopping criteria
|
251 |
+
stopping_criteria = self._get_stopping_criteria(
|
252 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
253 |
+
)
|
254 |
+
logits_warper = self._get_logits_warper(generation_config)
|
255 |
+
|
256 |
+
if is_greedy_gen_mode:
|
257 |
+
# 11. run greedy search
|
258 |
+
return self.stream_greedy_search(
|
259 |
+
input_ids,
|
260 |
+
logits_processor,
|
261 |
+
stopping_criteria,
|
262 |
+
generation_config,
|
263 |
+
synced_gpus,
|
264 |
+
**model_kwargs,
|
265 |
+
)
|
266 |
+
elif is_sample_gen_mode:
|
267 |
+
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
268 |
+
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
269 |
+
input_ids=input_ids,
|
270 |
+
expand_size=generation_config.num_return_sequences,
|
271 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
272 |
+
**model_kwargs,
|
273 |
+
)
|
274 |
+
return self.stream_sample(
|
275 |
+
generation_config,
|
276 |
+
input_ids,
|
277 |
+
logits_processor,
|
278 |
+
logits_warper,
|
279 |
+
stopping_criteria,
|
280 |
+
synced_gpus,
|
281 |
+
**model_kwargs,
|
282 |
+
)
|
283 |
+
elif is_beam_gen_mode:
|
284 |
+
return self.stream_beam_search(
|
285 |
+
generation_config,
|
286 |
+
input_ids,
|
287 |
+
logits_processor,
|
288 |
+
stopping_criteria,
|
289 |
+
synced_gpus,
|
290 |
+
**model_kwargs,
|
291 |
+
)
|
292 |
+
elif is_beam_sample_gen_mode:
|
293 |
+
# interleave input_ids with `num_beams` additional sequences per batch
|
294 |
+
return self.stream_beam_sample(
|
295 |
+
input_ids,
|
296 |
+
logits_processor,
|
297 |
+
logits_warper,
|
298 |
+
stopping_criteria,
|
299 |
+
generation_config,
|
300 |
+
synced_gpus,
|
301 |
+
**model_kwargs,
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
raise Exception('not implement')
|
305 |
+
|
306 |
+
def stream_sample(
|
307 |
+
self,
|
308 |
+
generation_config,
|
309 |
+
input_ids,
|
310 |
+
logits_processor,
|
311 |
+
logits_warper,
|
312 |
+
stopping_criteria,
|
313 |
+
synced_gpus,
|
314 |
+
**model_kwargs,
|
315 |
+
):
|
316 |
+
bos_token_id, eos_token_id, pad_token_id = (
|
317 |
+
generation_config.bos_token_id,
|
318 |
+
generation_config.eos_token_id,
|
319 |
+
generation_config.pad_token_id,
|
320 |
+
)
|
321 |
+
if isinstance(eos_token_id, int):
|
322 |
+
eos_token_id = [eos_token_id]
|
323 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
324 |
+
# keep track of which sequences are already finished
|
325 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
326 |
+
this_peer_finished = False # used by synced_gpus only
|
327 |
+
scores=()
|
328 |
+
# auto-regressive generation
|
329 |
+
while True:
|
330 |
+
if synced_gpus:
|
331 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
332 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
333 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
334 |
+
# send 0.0 if we finished, 1.0 otherwise
|
335 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
336 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
337 |
+
if this_peer_finished_flag.item() == 0.0:
|
338 |
+
break
|
339 |
+
# prepare model inputs
|
340 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
341 |
+
# forward pass to get next token
|
342 |
+
outputs = self(
|
343 |
+
**model_inputs,
|
344 |
+
return_dict=True,
|
345 |
+
)
|
346 |
+
if synced_gpus and this_peer_finished:
|
347 |
+
continue # don't waste resources running the code we don't need
|
348 |
+
next_token_logits = outputs.logits[:, -1, :]
|
349 |
+
# pre-process distribution
|
350 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
351 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
352 |
+
|
353 |
+
# sample
|
354 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
355 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
356 |
+
|
357 |
+
# finished sentences should have their next token be a padding token
|
358 |
+
if eos_token_id is not None:
|
359 |
+
if pad_token_id is None:
|
360 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
361 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
362 |
+
|
363 |
+
# update generated ids, model inputs, and length for next step
|
364 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
365 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
366 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
367 |
+
)
|
368 |
+
yield input_ids
|
369 |
+
# if eos_token was found in one sentence, set sentence to finished
|
370 |
+
if eos_token_id_tensor is not None:
|
371 |
+
unfinished_sequences = unfinished_sequences.mul(
|
372 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
373 |
+
)
|
374 |
+
|
375 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
376 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
377 |
+
if not synced_gpus:
|
378 |
+
break
|
379 |
+
else:
|
380 |
+
this_peer_finished = True
|
381 |
+
yield input_ids
|
382 |
+
|
383 |
+
def stream_beam_sample(
|
384 |
+
self,
|
385 |
+
input_ids,
|
386 |
+
logits_processor,
|
387 |
+
logits_warper,
|
388 |
+
stopping_criteria,
|
389 |
+
generation_config,
|
390 |
+
synced_gpus,
|
391 |
+
**model_kwargs,
|
392 |
+
):
|
393 |
+
bos_token_id, eos_token_id, pad_token_id = (
|
394 |
+
generation_config.bos_token_id,
|
395 |
+
generation_config.eos_token_id,
|
396 |
+
generation_config.pad_token_id,
|
397 |
+
)
|
398 |
+
if isinstance(eos_token_id, int):
|
399 |
+
eos_token_id = [eos_token_id]
|
400 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
401 |
+
num_beams = generation_config.num_beams
|
402 |
+
batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
|
403 |
+
beam_scorer = BeamSearchScorer(
|
404 |
+
batch_size=batch_size,
|
405 |
+
num_beams=generation_config.num_beams,
|
406 |
+
device=input_ids.device,
|
407 |
+
length_penalty=generation_config.length_penalty,
|
408 |
+
do_early_stopping=generation_config.early_stopping,
|
409 |
+
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
410 |
+
max_length=generation_config.max_length,
|
411 |
+
)
|
412 |
+
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
413 |
+
input_ids=input_ids,
|
414 |
+
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
|
415 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
416 |
+
**model_kwargs,
|
417 |
+
)
|
418 |
+
scores = ()
|
419 |
+
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
420 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
421 |
+
|
422 |
+
this_peer_finished = False # used by synced_gpus only
|
423 |
+
while True:
|
424 |
+
if synced_gpus:
|
425 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
426 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
427 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
428 |
+
# send 0.0 if we finished, 1.0 otherwise
|
429 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
430 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
431 |
+
if this_peer_finished_flag.item() == 0.0:
|
432 |
+
break
|
433 |
+
|
434 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
435 |
+
outputs = self(
|
436 |
+
**model_inputs,
|
437 |
+
return_dict=True,
|
438 |
+
)
|
439 |
+
|
440 |
+
if synced_gpus and this_peer_finished:
|
441 |
+
cur_len = cur_len + 1
|
442 |
+
continue # don't waste resources running the code we don't need
|
443 |
+
|
444 |
+
next_token_logits = outputs.logits[:, -1, :]
|
445 |
+
|
446 |
+
# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
|
447 |
+
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
|
448 |
+
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
|
449 |
+
next_token_scores = nn.functional.log_softmax(
|
450 |
+
next_token_logits, dim=-1
|
451 |
+
) # (batch_size * num_beams, vocab_size)
|
452 |
+
|
453 |
+
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
454 |
+
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
|
455 |
+
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
|
456 |
+
# (like top_p) this is indiferent, but on others (like temperature) it is not. For reference, see
|
457 |
+
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
|
458 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
459 |
+
|
460 |
+
# reshape for beam search
|
461 |
+
vocab_size = next_token_scores.shape[-1]
|
462 |
+
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
463 |
+
|
464 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
465 |
+
|
466 |
+
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
|
467 |
+
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
|
468 |
+
|
469 |
+
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
|
470 |
+
next_tokens = torch.gather(next_tokens, -1, _indices)
|
471 |
+
|
472 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
473 |
+
next_tokens = next_tokens % vocab_size
|
474 |
+
|
475 |
+
# stateless
|
476 |
+
beam_outputs = beam_scorer.process(
|
477 |
+
input_ids,
|
478 |
+
next_token_scores,
|
479 |
+
next_tokens,
|
480 |
+
next_indices,
|
481 |
+
pad_token_id=pad_token_id,
|
482 |
+
eos_token_id=eos_token_id,
|
483 |
+
beam_indices=None,
|
484 |
+
)
|
485 |
+
beam_scores = beam_outputs["next_beam_scores"]
|
486 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
487 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
488 |
+
|
489 |
+
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
|
490 |
+
yield input_ids
|
491 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
492 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
493 |
+
)
|
494 |
+
if model_kwargs["past_key_values"] is not None:
|
495 |
+
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
|
496 |
+
|
497 |
+
# increase cur_len
|
498 |
+
cur_len = cur_len + 1
|
499 |
+
|
500 |
+
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
|
501 |
+
if not synced_gpus:
|
502 |
+
break
|
503 |
+
else:
|
504 |
+
this_peer_finished = True
|
505 |
+
|
506 |
+
sequence_outputs = beam_scorer.finalize(
|
507 |
+
input_ids,
|
508 |
+
beam_scores,
|
509 |
+
next_tokens,
|
510 |
+
next_indices,
|
511 |
+
pad_token_id=pad_token_id,
|
512 |
+
eos_token_id=eos_token_id,
|
513 |
+
max_length=stopping_criteria.max_length,
|
514 |
+
beam_indices=None,
|
515 |
+
)
|
516 |
+
yield sequence_outputs["sequences"]
|
517 |
+
|
518 |
+
def stream_greedy_search(
|
519 |
+
self,
|
520 |
+
input_ids,
|
521 |
+
logits_processor,
|
522 |
+
stopping_criteria,
|
523 |
+
generation_config,
|
524 |
+
synced_gpus,
|
525 |
+
**model_kwargs,
|
526 |
+
):
|
527 |
+
# init values
|
528 |
+
bos_token_id, eos_token_id, pad_token_id = (
|
529 |
+
generation_config.bos_token_id,
|
530 |
+
generation_config.eos_token_id,
|
531 |
+
generation_config.pad_token_id,
|
532 |
+
)
|
533 |
+
if isinstance(eos_token_id, int):
|
534 |
+
eos_token_id = [eos_token_id]
|
535 |
+
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
536 |
+
# init attention / hidden states / scores tuples
|
537 |
+
scores = ()
|
538 |
+
# keep track of which sequences are already finished
|
539 |
+
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
540 |
+
this_peer_finished = False # used by synced_gpus only
|
541 |
+
while True:
|
542 |
+
if synced_gpus:
|
543 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
544 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
545 |
+
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
546 |
+
# send 0.0 if we finished, 1.0 otherwise
|
547 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
548 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
549 |
+
if this_peer_finished_flag.item() == 0.0:
|
550 |
+
break
|
551 |
+
|
552 |
+
# prepare model inputs
|
553 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
554 |
+
# forward pass to get next token
|
555 |
+
outputs = self(
|
556 |
+
**model_inputs,
|
557 |
+
return_dict=True,
|
558 |
+
)
|
559 |
+
|
560 |
+
if synced_gpus and this_peer_finished:
|
561 |
+
continue # don't waste resources running the code we don't need
|
562 |
+
|
563 |
+
next_token_logits = outputs.logits[:, -1, :]
|
564 |
+
# pre-process distribution
|
565 |
+
next_tokens_scores = logits_processor(input_ids, next_token_logits)
|
566 |
+
# argmax
|
567 |
+
next_tokens = torch.argmax(next_tokens_scores, dim=-1)
|
568 |
+
# finished sentences should have their next token be a padding token
|
569 |
+
if eos_token_id is not None:
|
570 |
+
if pad_token_id is None:
|
571 |
+
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
572 |
+
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
573 |
+
# update generated ids, model inputs, and length for next step
|
574 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
575 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
576 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
577 |
+
)
|
578 |
+
yield input_ids
|
579 |
+
# if eos_token was found in one sentence, set sentence to finished
|
580 |
+
if eos_token_id_tensor is not None:
|
581 |
+
unfinished_sequences = unfinished_sequences.mul(
|
582 |
+
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
583 |
+
)
|
584 |
+
|
585 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
586 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
587 |
+
if not synced_gpus:
|
588 |
+
break
|
589 |
+
else:
|
590 |
+
this_peer_finished = True
|
591 |
+
yield input_ids
|
592 |
+
|
593 |
+
def stream_beam_search(
|
594 |
+
self,
|
595 |
+
generation_config,
|
596 |
+
input_ids,
|
597 |
+
logits_processor,
|
598 |
+
stopping_criteria,
|
599 |
+
synced_gpus,
|
600 |
+
**model_kwargs,
|
601 |
+
):
|
602 |
+
|
603 |
+
# 10. go into beam search generation modes
|
604 |
+
# 11. prepare beam search scorer
|
605 |
+
bos_token_id, eos_token_id, pad_token_id = (
|
606 |
+
generation_config.bos_token_id,
|
607 |
+
generation_config.eos_token_id,
|
608 |
+
generation_config.pad_token_id,
|
609 |
+
)
|
610 |
+
if isinstance(eos_token_id, int):
|
611 |
+
eos_token_id = [eos_token_id]
|
612 |
+
num_beams = generation_config.num_beams
|
613 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
614 |
+
beam_scorer = BeamSearchScorer(
|
615 |
+
batch_size=batch_size,
|
616 |
+
num_beams=generation_config.num_beams,
|
617 |
+
device=input_ids.device,
|
618 |
+
length_penalty=generation_config.length_penalty,
|
619 |
+
do_early_stopping=generation_config.early_stopping,
|
620 |
+
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
621 |
+
max_length=generation_config.max_length,
|
622 |
+
)
|
623 |
+
# 12. interleave input_ids with `num_beams` additional sequences per batch
|
624 |
+
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
625 |
+
input_ids=input_ids,
|
626 |
+
expand_size=generation_config.num_beams,
|
627 |
+
is_encoder_decoder=self.config.is_encoder_decoder,
|
628 |
+
**model_kwargs,
|
629 |
+
)
|
630 |
+
# beam_search logits
|
631 |
+
batch_beam_size, cur_len = input_ids.shape
|
632 |
+
if num_beams * batch_size != batch_beam_size:
|
633 |
+
raise ValueError(
|
634 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
635 |
+
)
|
636 |
+
beam_scores = torch.zeros(
|
637 |
+
(batch_size, num_beams), dtype=torch.float, device=input_ids.device
|
638 |
+
)
|
639 |
+
beam_scores[:, 1:] = -1e9
|
640 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
641 |
+
this_peer_finished = False # used by synced_gpus only
|
642 |
+
while True:
|
643 |
+
|
644 |
+
if synced_gpus:
|
645 |
+
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
646 |
+
# The following logic allows an early break if all peers finished generating their sequence
|
647 |
+
this_peer_finished_flag = torch.tensor(
|
648 |
+
0.0 if this_peer_finished else 1.0
|
649 |
+
).to(input_ids.device)
|
650 |
+
# send 0.0 if we finished, 1.0 otherwise
|
651 |
+
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
652 |
+
# did all peers finish? the reduced sum will be 0.0 then
|
653 |
+
if this_peer_finished_flag.item() == 0.0:
|
654 |
+
break
|
655 |
+
|
656 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
657 |
+
outputs = self(
|
658 |
+
**model_inputs,
|
659 |
+
return_dict=True,
|
660 |
+
output_attentions=False,
|
661 |
+
output_hidden_states=False,
|
662 |
+
)
|
663 |
+
|
664 |
+
if synced_gpus and this_peer_finished:
|
665 |
+
cur_len = cur_len + 1
|
666 |
+
continue # don't waste resources running the code we don't need
|
667 |
+
|
668 |
+
next_token_logits = outputs.logits[:, -1, :]
|
669 |
+
# next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) hack: adjust tokens for Marian.
|
670 |
+
next_token_scores = nn.functional.log_softmax(
|
671 |
+
next_token_logits, dim=-1
|
672 |
+
) # (batch_size * num_beams, vocab_size)
|
673 |
+
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
674 |
+
next_token_scores = next_token_scores_processed + beam_scores[
|
675 |
+
:, None
|
676 |
+
].expand_as(next_token_scores)
|
677 |
+
|
678 |
+
# reshape for beam search
|
679 |
+
vocab_size = next_token_scores.shape[-1]
|
680 |
+
next_token_scores = next_token_scores.view(
|
681 |
+
batch_size, num_beams * vocab_size
|
682 |
+
)
|
683 |
+
|
684 |
+
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
|
685 |
+
next_token_scores, next_tokens = torch.topk(
|
686 |
+
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
|
687 |
+
)
|
688 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
689 |
+
next_tokens = next_tokens % vocab_size
|
690 |
+
# stateless
|
691 |
+
beam_outputs = beam_scorer.process(
|
692 |
+
input_ids,
|
693 |
+
next_token_scores,
|
694 |
+
next_tokens,
|
695 |
+
next_indices,
|
696 |
+
pad_token_id=pad_token_id,
|
697 |
+
eos_token_id=eos_token_id,
|
698 |
+
beam_indices=None,
|
699 |
+
)
|
700 |
+
beam_scores = beam_outputs["next_beam_scores"]
|
701 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
702 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
703 |
+
|
704 |
+
input_ids = torch.cat(
|
705 |
+
[input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
|
706 |
+
)
|
707 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
708 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
709 |
+
)
|
710 |
+
if model_kwargs["past_key_values"] is not None:
|
711 |
+
model_kwargs["past_key_values"] = self._reorder_cache(
|
712 |
+
model_kwargs["past_key_values"], beam_idx
|
713 |
+
)
|
714 |
+
|
715 |
+
# increase cur_len
|
716 |
+
cur_len = cur_len + 1
|
717 |
+
|
718 |
+
yield input_ids
|
719 |
+
|
720 |
+
if beam_scorer.is_done or stopping_criteria(input_ids, None):
|
721 |
+
if not synced_gpus:
|
722 |
+
break
|
723 |
+
else:
|
724 |
+
this_peer_finished = True
|
725 |
+
|
726 |
+
final_result = beam_scorer.finalize(
|
727 |
+
input_ids,
|
728 |
+
beam_scores,
|
729 |
+
next_tokens,
|
730 |
+
next_indices,
|
731 |
+
pad_token_id=pad_token_id,
|
732 |
+
eos_token_id=eos_token_id,
|
733 |
+
max_length=stopping_criteria.max_length,
|
734 |
+
beam_indices=None,
|
735 |
+
)
|
736 |
+
yield final_result["sequences"]
|
737 |
+
|
738 |
+
class StreamLlamaForCausalLM(LlamaForCausalLM, StreamGenerationMixin):
|
739 |
+
pass
|
740 |
+
|
741 |
+
class StreamPeftGenerationMixin(PeftModelForCausalLM, StreamGenerationMixin):
|
742 |
+
|
743 |
+
# default it call `model = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[config.task_type](model, config)`, not cls!! so inherent PeftModelForCausalLM is non sense
|
744 |
+
@classmethod
|
745 |
+
def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
|
746 |
+
# work in peft==0.3.0
|
747 |
+
if peft.__version__ >= '0.3.0' and peft.__version__ != '0.3.0.dev0':
|
748 |
+
# load the config
|
749 |
+
from peft.utils import PromptLearningConfig
|
750 |
+
config = LoraConfig.from_pretrained(model_id)
|
751 |
+
|
752 |
+
if (getattr(model, "hf_device_map", None) is not None) and len(
|
753 |
+
set(model.hf_device_map.values()).intersection({"cpu", "disk"})
|
754 |
+
) > 0:
|
755 |
+
remove_hook_from_submodules(model)
|
756 |
+
|
757 |
+
if isinstance(config, PromptLearningConfig) and is_trainable:
|
758 |
+
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
|
759 |
+
else:
|
760 |
+
config.inference_mode = not is_trainable
|
761 |
+
|
762 |
+
# here is the hack
|
763 |
+
model = cls(model, config, adapter_name)
|
764 |
+
model.load_adapter(model_id, adapter_name, **kwargs)
|
765 |
+
# NOTICE
|
766 |
+
model.base_model_prepare_inputs_for_generation = model.base_model.prepare_inputs_for_generation
|
767 |
+
model._reorder_cache = model.base_model._reorder_cache
|
768 |
+
return model
|
769 |
+
else:
|
770 |
+
return cls.from_pretrained_old_peft_version(model, model_id, **kwargs)
|
771 |
+
|
772 |
+
|
773 |
+
@classmethod
|
774 |
+
def from_pretrained_old_peft_version(cls, model, model_id, **kwargs):
|
775 |
+
# work well in peft@e536616888d51b453ed354a6f1e243fecb02ea08
|
776 |
+
|
777 |
+
# load the config
|
778 |
+
config = LoraConfig.from_pretrained(model_id)
|
779 |
+
|
780 |
+
if getattr(model, "hf_device_map", None) is not None:
|
781 |
+
remove_hook_from_submodules(model)
|
782 |
+
|
783 |
+
# here is the hack
|
784 |
+
model = cls(model, config)
|
785 |
+
model._reorder_cache = model.base_model._reorder_cache
|
786 |
+
# load weights if any
|
787 |
+
if os.path.exists(os.path.join(model_id, "adapter_model.bin")):
|
788 |
+
filename = os.path.join(model_id, "adapter_model.bin")
|
789 |
+
else:
|
790 |
+
try:
|
791 |
+
filename = hf_hub_download(model_id, "adapter_model.bin")
|
792 |
+
except: # noqa
|
793 |
+
raise ValueError(
|
794 |
+
f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
|
795 |
+
f"Please check that the file {'adapter_model.bin'} is present at {model_id}."
|
796 |
+
)
|
797 |
+
|
798 |
+
adapters_weights = torch.load(
|
799 |
+
filename,
|
800 |
+
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
801 |
+
)
|
802 |
+
# load the weights into the model
|
803 |
+
model = set_peft_model_state_dict(model, adapters_weights)
|
804 |
+
if getattr(model, "hf_device_map", None) is not None:
|
805 |
+
device_map = kwargs.get("device_map", "auto")
|
806 |
+
max_memory = kwargs.get("max_memory", None)
|
807 |
+
no_split_module_classes = model._no_split_modules
|
808 |
+
if device_map != "sequential":
|
809 |
+
max_memory = get_balanced_memory(
|
810 |
+
model,
|
811 |
+
max_memory=max_memory,
|
812 |
+
no_split_module_classes=no_split_module_classes,
|
813 |
+
low_zero=(device_map == "balanced_low_0"),
|
814 |
+
)
|
815 |
+
if isinstance(device_map, str):
|
816 |
+
device_map = infer_auto_device_map(
|
817 |
+
model,
|
818 |
+
max_memory=max_memory,
|
819 |
+
no_split_module_classes=no_split_module_classes,
|
820 |
+
)
|
821 |
+
model = dispatch_model(model, device_map=device_map)
|
822 |
+
hook = AlignDevicesHook(io_same_device=True)
|
823 |
+
if model.peft_config.peft_type == PeftType.LORA:
|
824 |
+
add_hook_to_module(model.base_model.model, hook)
|
825 |
+
else:
|
826 |
+
remove_hook_from_submodules(model.prompt_encoder)
|
827 |
+
add_hook_to_module(model.base_model, hook)
|
828 |
+
return model
|