pseudotensor committed on
Commit: efe0924
1 Parent(s): 2204f8f

Add application file and dependencies

Files changed (9)
  1. LICENSE +201 -0
  2. app.py +1513 -0
  3. client_test.py +121 -0
  4. finetune.py +930 -0
  5. h2o-logo.svg +1 -0
  6. prompter.py +106 -0
  7. requirements.txt +44 -0
  8. stopping.py +139 -0
  9. utils.py +39 -0
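
For orientation only, below is a minimal launch sketch that is not part of this commit. It assumes the packages in requirements.txt are installed and a CUDA GPU is available, and it calls the main() entry point defined in app.py with options that appear in this diff; since app.py imports fire, the same options are presumably also available as CLI flags (e.g. --base_model=...).

# Illustrative sketch, not part of this commit: start the Gradio chat UI by
# calling app.main() directly with options taken from the diff below.
from app import main

main(
    base_model='h2oai/h2ogpt-oasst1-512-12b',  # the model app.py selects when running on HF Spaces
    load_8bit=True,   # reduces GPU memory use; not supported by every model
    chat=True,
    gradio=True,
)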
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,1513 @@
1
+ import functools
2
+ import inspect
3
+ import sys
4
+ import os
5
+ import traceback
6
+ import typing
7
+
8
+ from utils import set_seed, flatten_list, clear_torch_cache
9
+
10
+ SEED = 1236
11
+ set_seed(SEED)
12
+
13
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
14
+ from typing import Union
15
+ import numpy as np
16
+ import pandas as pd
17
+
18
+ import fire
19
+ import torch
20
+ from peft import PeftModel
21
+ from transformers import GenerationConfig, StoppingCriteriaList, AutoModel
22
+ from accelerate import init_empty_weights, infer_auto_device_map
23
+
24
+ from prompter import Prompter
25
+
26
+ from finetune import get_loaders, example_data_points, generate_prompt, get_githash, prompt_types_strings, \
27
+ human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
28
+ from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
29
+
30
+
31
+ def main(
32
+ load_8bit: bool = False,
33
+ load_half: bool = True,
34
+ infer_devices: bool = True,
35
+ base_model: str = '',
36
+ tokenizer_base_model: str = '',
37
+ lora_weights: str = "",
38
+ force_1_gpu: bool = True,
39
+
40
+ prompt_type: Union[int, str] = None,
41
+ # input to generation
42
+ temperature: float = None,
43
+ top_p: float = None,
44
+ top_k: int = None,
45
+ num_beams: int = None,
46
+ repetition_penalty: float = None,
47
+ num_return_sequences: int = None,
48
+ do_sample: bool = None,
49
+ max_new_tokens: int = None,
50
+ min_new_tokens: int = None,
51
+ early_stopping: Union[bool, str] = None,
52
+ max_time: float = None,
53
+
54
+ llama_type: bool = None,
55
+ debug: bool = False,
56
+ share: bool = True,
57
+ local_files_only: bool = False,
58
+ resume_download: bool = True,
59
+ use_auth_token: Union[str, bool] = False,  # True requires `huggingface-cli login` to have been run beforehand
60
+
61
+ src_lang: str = "English",
62
+ tgt_lang: str = "Russian",
63
+
64
+ gradio: bool = True,
65
+ gradio_avoid_processing_markdown: bool = True,
66
+ chat: bool = True,
67
+ chat_history: int = 4096, # character length of chat context/history
68
+ stream_output: bool = True,
69
+ show_examples: bool = None,
70
+ verbose: bool = False,
71
+ h2ocolors: bool = True,
72
+ height: int = 400,
73
+ show_lora: bool = True,
74
+ # set to True to load --base_model after client logs in,
75
+ # to be able to free GPU memory when model is swapped
76
+ login_mode_if_model0: bool = False,
77
+
78
+ sanitize_user_prompt: bool = True,
79
+ sanitize_bot_response: bool = True,
80
+
81
+ extra_model_options: typing.List[str] = [],
82
+ extra_lora_options: typing.List[str] = [],
83
+
84
+ score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
85
+ auto_score: bool = True,
86
+
87
+ eval_sharegpt_prompts_only: int = 0,
88
+ eval_sharegpt_prompts_only_seed: int = 1234,
89
+ eval_sharegpt_as_output: bool = False,
90
+ ):
91
+ # allow set token directly
92
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
93
+ # override share if in spaces
94
+ if os.environ.get("HUGGINGFACE_SPACES"):
95
+ share = False
96
+ base_model = 'h2oai/h2ogpt-oasst1-512-12b'
97
+ load_8bit = True
98
+
99
+ # get defaults
100
+ model_lower = base_model.lower()
101
+ if not gradio:
102
+ # force, else not single response like want to look at
103
+ stream_output = False
104
+ # else prompt removal can mess up output
105
+ chat = False
106
+
107
+ placeholder_instruction, placeholder_input, \
108
+ stream_output, show_examples, \
109
+ prompt_type, temperature, top_p, top_k, num_beams, \
110
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
111
+ repetition_penalty, num_return_sequences, \
112
+ do_sample, \
113
+ src_lang, tgt_lang, \
114
+ examples, \
115
+ task_info = \
116
+ get_generate_params(model_lower, chat,
117
+ stream_output, show_examples,
118
+ prompt_type, temperature, top_p, top_k, num_beams,
119
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
120
+ repetition_penalty, num_return_sequences,
121
+ do_sample,
122
+ )
123
+
124
+ if not gradio:
125
+ if eval_sharegpt_prompts_only > 0:
126
+ # override default examples with shareGPT ones for human-level eval purposes only
127
+ filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
128
+ if not os.path.isfile(filename):
129
+ os.system('wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
130
+ import json
131
+ data = json.load(open(filename, 'rt'))
132
+ # focus on data that starts with human, else likely chopped from other data
133
+ turn_start = 0 # odd in general
134
+ data = [x for x in data if len(x['conversations']) > turn_start + 1 and
135
+ x['conversations'][turn_start]['from'] == 'human' and
136
+ x['conversations'][turn_start + 1]['from'] == 'gpt']
137
+ np.random.seed(eval_sharegpt_prompts_only_seed)
138
+ example1 = examples[-1] # pick reference example
139
+ examples = []
140
+ responses = []
141
+ for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
142
+ assert data[i]['conversations'][turn_start]['from'] == 'human'
143
+ instruction = data[i]['conversations'][turn_start]['value']
144
+ assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
145
+ output = data[i]['conversations'][turn_start + 1]['value']
146
+ examplenew = example1.copy()
147
+ examplenew[0] = instruction
148
+ examplenew[1] = '' # no input
149
+ examplenew[2] = '' # no context
150
+ examples.append(examplenew)
151
+ responses.append(output)
152
+
153
+ with torch.device("cuda"):
154
+ # ensure was set right above before examples generated
155
+ assert not stream_output, "stream_output=True does not make sense with example loop"
156
+ import time
157
+ from functools import partial
158
+
159
+ # get score model
160
+ smodel, stokenizer, sdevice = get_score_model(**locals())
161
+
162
+ if not eval_sharegpt_as_output:
163
+ model, tokenizer, device = get_model(**locals())
164
+ model_state = [model, tokenizer, device, base_model]
165
+ fun = partial(evaluate, model_state, debug=debug, chat=chat)
166
+ else:
167
+ assert eval_sharegpt_prompts_only > 0
168
+
169
+ def get_response(*args, exi=0):
170
+ # assumes same ordering of examples and responses
171
+ yield responses[exi]
172
+
173
+ fun = get_response
174
+ t0 = time.time()
175
+ score_dump = []
176
+ num_examples = len(examples)
177
+
178
+ import matplotlib.pyplot as plt
179
+
180
+ for exi, ex in enumerate(examples):
181
+ clear_torch_cache()
182
+ print("")
183
+ print("START" + "=" * 100)
184
+ print("Question: %s %s" % (ex[0], ('input=%s' % ex[1] if ex[1] else '')))
185
+ print("-" * 105)
186
+ # fun yields as generator, so have to iterate over it
187
+ # Also means likely do NOT want --stream_output=True, else would show all generations
188
+ for res in fun(*tuple(ex), exi=exi):
189
+ print(res)
190
+ if smodel:
191
+ score_with_prompt = False
192
+ if score_with_prompt:
193
+ data_point = dict(instruction=ex[0], input=ex[1])
194
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
195
+ prompt = prompter.generate_prompt(data_point)
196
+ else:
197
+ # just raw input and output
198
+ assert ex[1] in [None, ''] # should be no iinput
199
+ assert ex[2] in [None, ''] # should be no context
200
+ prompt = ex[0]
201
+ cutoff_len = 768 if os.environ.get("HUGGINGFACE_SPACES") else 2048
202
+ inputs = stokenizer(prompt, res,
203
+ return_tensors="pt",
204
+ truncation=True,
205
+ max_length=cutoff_len)
206
+ try:
207
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
208
+ except torch.cuda.OutOfMemoryError as e:
209
+ print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
210
+ traceback.print_exc()
211
+ score = 0.0
212
+ clear_torch_cache()
213
+ print("SCORE %s: %s" % (exi, score), flush=True)
214
+ score_dump.append(ex + [prompt, res, score])
215
+ # dump every score in case abort
216
+ scoring_path = 'scoring'
217
+ os.makedirs(scoring_path, exist_ok=True)
218
+ if eval_sharegpt_as_output:
219
+ used_base_model = 'gpt35'
220
+ used_lora_weights = ''
221
+ else:
222
+ used_base_model = str(base_model.split('/')[-1])
223
+ used_lora_weights = str(lora_weights.split('/')[-1])
224
+ df_scores = pd.DataFrame(score_dump, columns=eval_func_param_names + ['prompt', 'response', 'score'])
225
+ filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
226
+ eval_sharegpt_prompts_only_seed,
227
+ eval_sharegpt_as_output,
228
+ used_base_model,
229
+ used_lora_weights)
230
+ filename = os.path.join(scoring_path, filename)
231
+ df_scores.to_parquet(filename, index=False)
232
+ # plot histogram so far
233
+ plt.figure(figsize=(10, 10))
234
+ plt.hist(df_scores['score'], bins=20)
235
+ score_avg = np.mean(df_scores['score'])
236
+ score_median = np.median(df_scores['score'])
237
+ plt.title("Score avg: %s median: %s" % (score_avg, score_median))
238
+ plt.savefig(filename.replace('.parquet', '.png'))
239
+ plt.close()
240
+
241
+ print("END" + "=" * 102)
242
+ print("")
243
+ t2 = time.time()
244
+ print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
245
+ t1 = time.time()
246
+ print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
247
+ return
248
+ if gradio:
249
+ go_gradio(**locals())
250
+
251
+
252
+ def get_device():
253
+ if torch.cuda.is_available():
254
+ device = "cuda"
255
+ else:
256
+ raise RuntimeError("only cuda supported")
257
+
258
+ return device
259
+
260
+
261
+ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type, force_1_gpu=True, use_auth_token=False):
262
+ """
263
+ Ensure model gets on correct device
264
+ :param base_model:
265
+ :param model_loader:
266
+ :param load_half:
267
+ :param model_kwargs:
268
+ :param reward_type:
269
+ :return:
270
+ """
271
+ with init_empty_weights():
272
+ from transformers import AutoConfig
273
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
274
+ model = AutoModel.from_config(
275
+ config,
276
+ )
277
+
278
+ # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
279
+ # NOTE: Some models require avoiding sharding some layers,
280
+ # then would pass no_split_module_classes and give list of those layers.
281
+ device_map = infer_auto_device_map(
282
+ model,
283
+ dtype=torch.float16 if load_half else torch.float32,
284
+ )
285
+ if hasattr(model, 'model'):
286
+ device_map_model = infer_auto_device_map(
287
+ model.model,
288
+ dtype=torch.float16 if load_half else torch.float32,
289
+ )
290
+ device_map.update(device_map_model)
291
+ print('device_map: %s' % device_map, flush=True)
292
+
293
+ if force_1_gpu:
294
+ # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
295
+ # So avoid for now, just put on first GPU, unless score_model, put on last
296
+ n_gpus = torch.cuda.device_count()
297
+ if reward_type:
298
+ device_map = {'': n_gpus - 1}
299
+ else:
300
+ device_map = {'': 0}
301
+
302
+ load_in_8bit = model_kwargs.get('load_in_8bit', False)
303
+ model_kwargs['device_map'] = device_map
304
+
305
+ if load_in_8bit or not load_half:
306
+ model = model_loader.from_pretrained(
307
+ base_model,
308
+ **model_kwargs,
309
+ )
310
+ else:
311
+ model = model_loader.from_pretrained(
312
+ base_model,
313
+ **model_kwargs,
314
+ ).half()
315
+ return model
316
+
317
+
318
+ def get_model(
319
+ load_8bit: bool = False,
320
+ load_half: bool = True,
321
+ infer_devices: bool = True,
322
+ base_model: str = '',
323
+ tokenizer_base_model: str = '',
324
+ lora_weights: str = "",
325
+ force_1_gpu: bool = False,
326
+
327
+ llama_type: bool = None,
328
+ reward_type: bool = None,
329
+ local_files_only: bool = False,
330
+ resume_download: bool = True,
331
+ use_auth_token: Union[str, bool] = False,
332
+ compile: bool = True,
333
+ **kwargs,
334
+ ):
335
+ """
336
+
337
+ :param load_8bit: load model in 8-bit, not supported by all models
338
+ :param load_half: load model in 16-bit
339
+ :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
340
+ For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
341
+ So it is not the default
342
+ :param base_model: name/path of base model
343
+ :param tokenizer_base_model: name/path of tokenizer
344
+ :param lora_weights: name/path
345
+ :param force_1_gpu:
346
+ :param llama_type: whether LLaMa type model
347
+ :param reward_type: reward type model for sequence classification
348
+ :param local_files_only: use local files instead of from HF
349
+ :param resume_download: resume downloads from HF
350
+ :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
351
+ :param compile: whether to compile torch model
352
+ :param kwargs:
353
+ :return:
354
+ """
355
+ print("Get %s model" % base_model, flush=True)
356
+ if lora_weights is not None and lora_weights.strip():
357
+ print("Get %s lora weights" % lora_weights, flush=True)
358
+ device = get_device()
359
+
360
+ if 'gpt2' in base_model.lower():
361
+ # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
362
+ load_8bit = False
363
+
364
+ assert base_model.strip(), (
365
+ "Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
366
+ )
367
+ llama_type = llama_type or "llama" in base_model
368
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
369
+ if not tokenizer_base_model:
370
+ tokenizer_base_model = base_model
371
+
372
+ if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
373
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
374
+ local_files_only=local_files_only,
375
+ resume_download=resume_download,
376
+ use_auth_token=use_auth_token,
377
+ )
378
+ else:
379
+ tokenizer = tokenizer_loader
380
+
381
+ if isinstance(tokenizer, str):
382
+ # already a pipeline, tokenizer_loader is string for task
383
+ model = model_loader(tokenizer,
384
+ model=base_model,
385
+ device=0 if device == "cuda" else -1,
386
+ torch_dtype=torch.float16)
387
+ else:
388
+ assert device == "cuda", "Unsupported device %s" % device
389
+ model_kwargs = dict(local_files_only=local_files_only,
390
+ torch_dtype=torch.float16,
391
+ resume_download=resume_download,
392
+ use_auth_token=use_auth_token)
393
+ if 'mbart-' not in base_model.lower():
394
+ model_kwargs.update(dict(load_in_8bit=load_8bit,
395
+ device_map={"": 0} if load_8bit else "auto",
396
+ ))
397
+ if 'OpenAssistant/reward-model'.lower() in base_model.lower():
398
+ # could put on other GPUs
399
+ model_kwargs['device_map'] = {"": 0}
400
+ model_kwargs.pop('torch_dtype', None)
401
+
402
+ if not lora_weights:
403
+ with torch.device("cuda"):
404
+ if infer_devices:
405
+ model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
406
+ force_1_gpu=force_1_gpu, use_auth_token=use_auth_token)
407
+ else:
408
+ if load_half and not load_8bit:
409
+ model = model_loader.from_pretrained(
410
+ base_model,
411
+ **model_kwargs).half()
412
+ else:
413
+ model = model_loader.from_pretrained(
414
+ base_model,
415
+ **model_kwargs)
416
+ elif load_8bit:
417
+ model = model_loader.from_pretrained(
418
+ base_model,
419
+ **model_kwargs
420
+ )
421
+ model = PeftModel.from_pretrained(
422
+ model,
423
+ lora_weights,
424
+ torch_dtype=torch.float16,
425
+ local_files_only=local_files_only,
426
+ resume_download=resume_download,
427
+ use_auth_token=use_auth_token,
428
+ device_map={"": 0}, # seems to be required
429
+ )
430
+ else:
431
+ with torch.device("cuda"):
432
+ model = model_loader.from_pretrained(
433
+ base_model,
434
+ **model_kwargs
435
+ )
436
+ model = PeftModel.from_pretrained(
437
+ model,
438
+ lora_weights,
439
+ torch_dtype=torch.float16,
440
+ local_files_only=local_files_only,
441
+ resume_download=resume_download,
442
+ use_auth_token=use_auth_token,
443
+ device_map="auto",
444
+ )
445
+ if load_half:
446
+ model.half()
447
+
448
+ # unwind broken decapoda-research config
449
+ if llama_type:
450
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
451
+ model.config.bos_token_id = 1
452
+ model.config.eos_token_id = 2
453
+ if 'gpt2' in base_model.lower():
454
+ # add special tokens that otherwise all share the same id
455
+ tokenizer.add_special_tokens({'bos_token': '<bos>',
456
+ 'eos_token': '<eos>',
457
+ 'pad_token': '<pad>'})
458
+
459
+ if not isinstance(tokenizer, str):
460
+ model.eval()
461
+ if torch.__version__ >= "2" and sys.platform != "win32" and compile:
462
+ model = torch.compile(model)
463
+
464
+ return model, tokenizer, device
465
+
466
+
467
+ def get_score_model(**kwargs):
468
+ # score model
469
+ if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
470
+ score_all_kwargs = kwargs.copy()
471
+ score_all_kwargs['load_8bit'] = False
472
+ score_all_kwargs['load_half'] = False
473
+ score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
474
+ score_all_kwargs['tokenizer_base_model'] = ''
475
+ score_all_kwargs['lora_weights'] = ''
476
+ score_all_kwargs['llama_type'] = False
477
+ score_all_kwargs['compile'] = False
478
+ smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
479
+ else:
480
+ smodel, stokenizer, sdevice = None, None, None
481
+ return smodel, stokenizer, sdevice
482
+
483
+
484
+ def go_gradio(**kwargs):
485
+
486
+ # get default model
487
+ all_kwargs = kwargs.copy()
488
+ all_kwargs.update(locals())
489
+ if kwargs.get('base_model') and not kwargs['login_mode_if_model0']:
490
+ model0, tokenizer0, device = get_model(**all_kwargs)
491
+ else:
492
+ # if empty model, then don't load anything, just get gradio up
493
+ model0, tokenizer0, device = None, None, None
494
+ model_state0 = [model0, tokenizer0, device, kwargs['base_model']]
495
+
496
+ # get score model
497
+ smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
498
+
499
+ if 'mbart-' in kwargs['model_lower']:
500
+ instruction_label = "Text to translate"
501
+ else:
502
+ instruction_label = "Instruction"
503
+ if kwargs['chat']:
504
+ instruction_label = "You (Shift-Enter or push Submit to send message)"
505
+
506
+ title = 'h2oGPT'
507
+ if kwargs['verbose']:
508
+ description = f"""Model {kwargs['base_model']} Instruct dataset.
509
+ For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).
510
+ Command: {str(' '.join(sys.argv))}
511
+ Hash: {get_githash()}
512
+ """
513
+ else:
514
+ description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
515
+ if os.environ.get("HUGGINGFACE_SPACES"):
516
+ description += """<p><b> DISCLAIMERS: </b><ul><i><li>The data used to train this model include The Pile and other sources. These may contain objectionable content, so the model may reproduce that material. Use application and responses at own risk.</i></li>"""
517
+ if kwargs['load_8bit']:
518
+ description += """<i><li> Model is loaded in 8-bit and 768 token context length to fit on HF GPUs, so model may perform worse than 16-bit with 2048 token limit.</i></li>"""
519
+ description += """<i><li>Model loading and unloading disabled on HF SPACES to avoid GPU OOM for multi-user environment.</i></li></ul></p>"""
520
+
521
+ if kwargs['verbose']:
522
+ task_info_md = f"""
523
+ ### Task: {kwargs['task_info']}"""
524
+ else:
525
+ task_info_md = ''
526
+
527
+ css_code = """footer {visibility: hidden}
528
+ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}}"""
529
+
530
+ from gradio.themes.utils import colors, fonts, sizes
531
+ if kwargs['h2ocolors']:
532
+ colors_dict = dict(primary_hue=colors.yellow,
533
+ secondary_hue=colors.yellow,
534
+ neutral_hue=colors.gray,
535
+ spacing_size=sizes.spacing_md,
536
+ radius_size=sizes.radius_md,
537
+ text_size=sizes.text_md,
538
+ )
539
+ else:
540
+ colors_dict = dict(primary_hue=colors.indigo,
541
+ secondary_hue=colors.indigo,
542
+ neutral_hue=colors.gray,
543
+ spacing_size=sizes.spacing_md,
544
+ radius_size=sizes.radius_md,
545
+ text_size=sizes.text_md,
546
+ )
547
+
548
+ import gradio as gr
549
+
550
+ if kwargs['gradio_avoid_processing_markdown']:
551
+ from gradio_client import utils as client_utils
552
+ from gradio.components import Chatbot
553
+
554
+ # gradio has issue with taking too long to process input/output for markdown etc.
555
+ # Avoid for now, allow raw html to render, good enough for chatbot.
556
+ def _postprocess_chat_messages(self, chat_message: str):
557
+ if chat_message is None:
558
+ return None
559
+ elif isinstance(chat_message, (tuple, list)):
560
+ filepath = chat_message[0]
561
+ mime_type = client_utils.get_mimetype(filepath)
562
+ filepath = self.make_temp_copy_if_needed(filepath)
563
+ return {
564
+ "name": filepath,
565
+ "mime_type": mime_type,
566
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
567
+ "data": None, # These last two fields are filled in by the frontend
568
+ "is_file": True,
569
+ }
570
+ elif isinstance(chat_message, str):
571
+ return chat_message
572
+ else:
573
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
574
+ Chatbot._postprocess_chat_messages = _postprocess_chat_messages
575
+
576
+ demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
577
+ callback = gr.CSVLogger()
578
+ # css_code = 'body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}'
579
+ # demo = gr.Blocks(theme='gstaff/xkcd', css=css_code)
580
+
581
+ model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
582
+ if kwargs['base_model'].strip() not in model_options:
583
+ lora_options = [kwargs['base_model'].strip()] + model_options
584
+ lora_options = kwargs['extra_lora_options']
585
+ if kwargs['lora_weights'].strip() not in lora_options:
586
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
587
+ # always add in no lora case
588
+ # add fake space so doesn't go away in gradio dropdown
589
+ lora_options = [' '] + kwargs['extra_lora_options']
590
+
591
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get('base_model') else 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
592
+
593
+ with demo:
594
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
595
+ # https://github.com/gradio-app/gradio/issues/3558
596
+ model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
597
+ model_options_state = gr.State([model_options])
598
+ lora_options_state = gr.State([lora_options])
599
+ gr.Markdown(
600
+ f"""
601
+ <h1 align="center"> {title}</h1>
602
+
603
+ {description}
604
+ {task_info_md}
605
+ """)
606
+
607
+ # go button visible if
608
+ base_wanted = bool(kwargs['base_model']) and kwargs['login_mode_if_model0']
609
+ go_btn = gr.Button(value="LOGIN", visible=base_wanted, variant="primary")
610
+ normal_block = gr.Row(visible=not base_wanted)
611
+ with normal_block:
612
+ with gr.Tabs():
613
+ with gr.Row():
614
+ if not kwargs['chat']:
615
+ with gr.Column():
616
+ instruction = gr.Textbox(
617
+ lines=4, label=instruction_label,
618
+ placeholder=kwargs['placeholder_instruction'],
619
+ )
620
+ iinput = gr.Textbox(lines=4, label="Input",
621
+ placeholder=kwargs['placeholder_input'])
622
+ flag_btn = gr.Button("Flag")
623
+ if kwargs['score_model']:
624
+ if not kwargs['auto_score']:
625
+ with gr.Column():
626
+ score_btn = gr.Button("Score last prompt & response")
627
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
628
+ else:
629
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
630
+ with gr.Column():
631
+ if kwargs['chat']:
632
+ text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
633
+ with gr.Row():
634
+ with gr.Column(scale=50):
635
+ instruction = gr.Textbox(
636
+ lines=4, label=instruction_label,
637
+ placeholder=kwargs['placeholder_instruction'],
638
+ )
639
+ with gr.Row(): # .style(equal_height=False, equal_width=False):
640
+ submit = gr.Button(value='Submit').style(full_width=False, size='sm')
641
+ stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
642
+ with gr.Row():
643
+ clear = gr.Button("New Conversation")
644
+ flag_btn = gr.Button("Flag")
645
+ if kwargs['score_model']:
646
+ if not kwargs['auto_score']:
647
+ with gr.Column():
648
+ score_btn = gr.Button("Score last prompt & response").style(full_width=False, size='sm')
649
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
650
+ else:
651
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
652
+ retry = gr.Button("Regenerate")
653
+ undo = gr.Button("Undo")
654
+ else:
655
+ text_output = gr.Textbox(lines=5, label=output_label0)
656
+ with gr.TabItem("Input/Output"):
657
+ with gr.Row():
658
+ if 'mbart-' in kwargs['model_lower']:
659
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
660
+ value=kwargs['src_lang'],
661
+ label="Input Language")
662
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
663
+ value=kwargs['tgt_lang'],
664
+ label="Output Language")
665
+ with gr.TabItem("Expert"):
666
+ with gr.Row():
667
+ with gr.Column():
668
+ stream_output = gr.components.Checkbox(label="Stream output",
669
+ value=kwargs['stream_output'])
670
+ prompt_type = gr.Dropdown(prompt_types_strings,
671
+ value=kwargs['prompt_type'], label="Prompt Type")
672
+ temperature = gr.Slider(minimum=0, maximum=3,
673
+ value=kwargs['temperature'],
674
+ label="Temperature",
675
+ info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
676
+ top_p = gr.Slider(minimum=0, maximum=1,
677
+ value=kwargs['top_p'], label="Top p",
678
+ info="Cumulative probability of tokens to sample from")
679
+ top_k = gr.Slider(
680
+ minimum=0, maximum=100, step=1,
681
+ value=kwargs['top_k'], label="Top k",
682
+ info='Num. tokens to sample from'
683
+ )
684
+ num_beams = gr.Slider(minimum=1, maximum=8, step=1,
685
+ value=kwargs['num_beams'], label="Beams",
686
+ info="Number of searches for optimal overall probability. Uses more GPU memory/compute")
687
+ max_new_tokens = gr.Slider(
688
+ minimum=1, maximum=2048, step=1,
689
+ value=kwargs['max_new_tokens'], label="Max output length"
690
+ )
691
+ min_new_tokens = gr.Slider(
692
+ minimum=0, maximum=2048, step=1,
693
+ value=kwargs['min_new_tokens'], label="Min output length"
694
+ )
695
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
696
+ value=kwargs['early_stopping'])
697
+ max_time = gr.Slider(minimum=0, maximum=60 * 5, step=1,
698
+ value=kwargs['max_time'], label="Max. time",
699
+ info="Max. time to search optimal output.")
700
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
701
+ value=kwargs['repetition_penalty'],
702
+ label="Repetition Penalty")
703
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
704
+ value=kwargs['num_return_sequences'],
705
+ label="Number Returns", info="Must be <= num_beams")
706
+ do_sample = gr.Checkbox(label="Sample", info="Sample, for diverse output(s)",
707
+ value=kwargs['do_sample'])
708
+ if kwargs['chat']:
709
+ iinput = gr.Textbox(lines=4, label="Input",
710
+ placeholder=kwargs['placeholder_input'])
711
+ context = gr.Textbox(lines=1, label="Context",
712
+ info="Ignored in chat mode.") # nominally empty for chat mode
713
+
714
+ with gr.TabItem("Models"):
715
+ with gr.Row():
716
+ with gr.Column():
717
+ with gr.Row(scale=1):
718
+ with gr.Column(scale=50):
719
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model", value=kwargs['base_model'])
720
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
721
+ with gr.Column(scale=1):
722
+ load_msg = "Load Model/LORA" if not os.environ.get("HUGGINGFACE_SPACES") \
723
+ else "LOAD DISABLED ON HF SPACES"
724
+ load_model_button = gr.Button(load_msg)
725
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
726
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], visible=kwargs['show_lora'])
727
+ with gr.Row(scale=1):
728
+ with gr.Column(scale=50):
729
+ new_model = gr.Textbox(label="New Model HF name/path")
730
+ new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
731
+ with gr.Column(scale=1):
732
+ add_model_button = gr.Button("Add new model name")
733
+ add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
734
+
735
+ inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
736
+ from functools import partial
737
+ all_kwargs = kwargs.copy()
738
+ all_kwargs.update(locals())
739
+ kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
740
+ fun = partial(evaluate,
741
+ **kwargs_evaluate)
742
+
743
+ dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
744
+ size="sm",
745
+ )
746
+ dark_mode_btn.click(
747
+ None,
748
+ None,
749
+ None,
750
+ _js="""() => {
751
+ if (document.querySelectorAll('.dark').length) {
752
+ document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
753
+ } else {
754
+ document.querySelector('body').classList.add('dark');
755
+ }
756
+ }""",
757
+ api_name="dark",
758
+ )
759
+ if not kwargs['chat']:
760
+ submit = gr.Button("Submit")
761
+ submit_event = submit.click(fun, inputs=[model_state] + inputs_list, outputs=text_output, api_name='submit')
762
+
763
+ # examples after submit or any other buttons for chat or no chat
764
+ if kwargs['examples'] is not None and kwargs['show_examples']:
765
+ gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
766
+
767
+ # Score
768
+ def score_last_response(*args):
769
+ """ Similar to user() """
770
+ args_list = list(args)
771
+ history = args_list[-1]
772
+ if history is None:
773
+ print("Bad history in scoring last response, fix for now", flush=True)
774
+ history = []
775
+ if smodel is not None and \
776
+ stokenizer is not None and \
777
+ sdevice is not None and \
778
+ history is not None and len(history) > 0 and \
779
+ history[-1] is not None and \
780
+ len(history[-1]) >= 2:
781
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
782
+
783
+ max_length_tokenize = 512 if os.environ.get("HUGGINGFACE_SPACES") else 2048
784
+ cutoff_len = max_length_tokenize*4 # restrict deberta related to max for LLM
785
+
786
+ question = history[-1][0]
787
+ question = question[-cutoff_len:]
788
+
789
+ answer = history[-1][1]
790
+ answer = answer[-cutoff_len:]
791
+
792
+ inputs = stokenizer(question, answer,
793
+ return_tensors="pt",
794
+ truncation=True,
795
+ max_length=max_length_tokenize).to(smodel.device)
796
+ try:
797
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
798
+ except torch.cuda.OutOfMemoryError as e:
799
+ print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
800
+ del inputs
801
+ traceback.print_exc()
802
+ clear_torch_cache()
803
+ return 'Response Score: GPU OOM'
804
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
805
+ return 'Response Score: {:.1%}'.format(score)
806
+ else:
807
+ return 'Response Score: NA'
808
+
809
+ if kwargs['score_model']:
810
+ score_args = dict(fn=score_last_response,
811
+ inputs=inputs_list + [text_output],
812
+ outputs=[score_text],
813
+ )
814
+ if not kwargs['auto_score']:
815
+ score_event = score_btn.click(**score_args, queue=stream_output, api_name='score')
816
+
817
+ if kwargs['chat']:
818
+ def user(*args, undo=False, sanitize_user_prompt=True):
819
+ args_list = list(args)
820
+ user_message = args_list[0]
821
+ input1 = args_list[1]
822
+ context1 = args_list[2]
823
+ if input1 and not user_message.endswith(':'):
824
+ user_message1 = user_message + ":" + input1
825
+ elif input1:
826
+ user_message1 = user_message + input1
827
+ else:
828
+ user_message1 = user_message
829
+ if sanitize_user_prompt:
830
+ from better_profanity import profanity
831
+ user_message1 = profanity.censor(user_message1)
832
+
833
+ history = args_list[-1]
834
+ if undo and history:
835
+ history.pop()
836
+ args_list = args_list[:-1]
837
+ if history is None:
838
+ print("Bad history, fix for now", flush=True)
839
+ history = []
840
+ if undo:
841
+ return "", history
842
+ else:
843
+ return "", history + [[user_message1, None]]
844
+
845
+ def bot(*args, retry=False):
846
+ args_list = list(args)
847
+ history = args_list[-1]
848
+ if retry and history:
849
+ history.pop()
850
+ if not history:
851
+ print("No history", flush=True)
852
+ return
853
+ instruction1 = history[-1][0]
854
+ context1 = ''
855
+ if kwargs['chat_history'] > 0:
856
+ prompt_type1 = args_list[prompt_type_arg_id]
857
+ context1 = ''
858
+ for histi in range(len(history) - 1):
859
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
860
+ context1 += generate_prompt(data_point, prompt_type1, kwargs['chat'], reduced=True)[0].replace(
861
+ '<br>', '\n')
862
+ if not context1.endswith('\n'):
863
+ context1 += '\n'
864
+ if context1 and not context1.endswith('\n'):
865
+ context1 += '\n' # ensure if terminates abruptly, then human continues on next line
866
+ args_list[0] = instruction1
867
+ # only include desired chat history
868
+ args_list[2] = context1[-kwargs['chat_history']:]
869
+ model_state1 = args_list[-2]
870
+ args_list = args_list[:-2]
871
+ fun1 = partial(evaluate,
872
+ model_state1,
873
+ **kwargs_evaluate)
874
+ try:
875
+ for output in fun1(*tuple(args_list)):
876
+ bot_message = output
877
+ history[-1][1] = bot_message
878
+ yield history
879
+ except StopIteration:
880
+ yield history
881
+ except RuntimeError as e:
882
+ if "generator raised StopIteration" in str(e):
883
+ # assume last entry was bad, undo
884
+ history.pop()
885
+ yield history
886
+ raise
887
+ except Exception as e:
888
+ # put error into user input
889
+ history[-1][0] = "Exception: %s" % str(e)
890
+ yield history
891
+ raise
892
+ return
893
+
894
+ user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
895
+ inputs=inputs_list + [text_output],
896
+ outputs=[instruction, text_output],
897
+ )
898
+ bot_args = dict(fn=bot,
899
+ inputs=inputs_list + [model_state] + [text_output],
900
+ outputs=[text_output],
901
+ )
902
+ retry_bot_args = dict(fn=functools.partial(bot, retry=True),
903
+ inputs=inputs_list + [model_state] + [text_output],
904
+ outputs=[text_output],
905
+ )
906
+ undo_user_args = dict(fn=functools.partial(user, undo=True),
907
+ inputs=inputs_list + [text_output],
908
+ outputs=[instruction, text_output],
909
+ )
910
+
911
+ if kwargs['auto_score']:
912
+ submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
913
+ **bot_args, api_name='instruction_bot',
914
+ ).then(**score_args, api_name='instruction_bot_score').then(clear_torch_cache)
915
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
916
+ **bot_args, api_name='submit_bot',
917
+ ).then(**score_args, api_name='submit_bot_score').then(clear_torch_cache)
918
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
919
+ **retry_bot_args, api_name='retry_bot',
920
+ ).then(**score_args, api_name='retry_bot_score').then(clear_torch_cache)
921
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo').then(**score_args, api_name='undo_score')
922
+ else:
923
+ submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction').then(
924
+ **bot_args, api_name='instruction_bot',
925
+ ).then(clear_torch_cache)
926
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit').then(
927
+ **bot_args, api_name='submit_bot',
928
+ ).then(clear_torch_cache)
929
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry').then(
930
+ **retry_bot_args, api_name='retry_bot',
931
+ ).then(clear_torch_cache)
932
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo')
933
+ clear.click(lambda: None, None, text_output, queue=False, api_name='clear')
934
+
935
+ def load_model(model_name, lora_weights, model_state_old, prompt_type_old):
936
+ # ensure old model removed from GPU memory
937
+ if kwargs['debug']:
938
+ print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
939
+
940
+ if isinstance(model_state_old[0], str) and model0 is not None:
941
+ # best can do, move model loaded at first to CPU
942
+ model0.cpu()
943
+
944
+ if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
945
+ try:
946
+ model_state_old[0].cpu()
947
+ except Exception as e:
948
+ # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
949
+ print("Unable to put model on CPU: %s" % str(e), flush=True)
950
+ del model_state_old[0]
951
+ model_state_old[0] = None
952
+
953
+ if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
954
+ del model_state_old[1]
955
+ model_state_old[1] = None
956
+
957
+ clear_torch_cache()
958
+ if kwargs['debug']:
959
+ print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
960
+ all_kwargs['base_model'] = model_name.strip()
961
+ model_lower = model_name.strip().lower()
962
+ if model_lower in inv_prompt_type_to_model_lower:
963
+ prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
964
+ else:
965
+ prompt_type1 = prompt_type_old
966
+
967
+ all_kwargs['lora_weights'] = lora_weights.strip()
968
+ model1, tokenizer1, device1 = get_model(**all_kwargs)
969
+ clear_torch_cache()
970
+
971
+ if kwargs['debug']:
972
+ print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
973
+ return {model_state: [model1, tokenizer1, device1, model_name],
974
+ model_used: model_name,
975
+ lora_used: lora_weights,
976
+ prompt_type: prompt_type1}
977
+
978
+ def dropdown_prompt_type_list(x):
979
+ return gr.Dropdown.update(value=x)
980
+
981
+ def chatbot_list(x, model_used_in):
982
+ return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
983
+
984
+ load_model_args = dict(fn=load_model,
985
+ inputs=[model_choice, lora_choice, model_state, prompt_type],
986
+ outputs=[model_state, model_used, lora_used, prompt_type])
987
+ prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
988
+ chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
989
+ if not os.environ.get("HUGGINGFACE_SPACES"):
990
+ load_model_event = load_model_button.click(**load_model_args).then(**prompt_update_args).then(**chatbot_update_args).then(clear_torch_cache)
991
+
992
+ def dropdown_model_list(list0, x):
993
+ new_state = [list0[0] + [x]]
994
+ new_options = [*new_state[0]]
995
+ return gr.Dropdown.update(value=x, choices=new_options), '', new_state
996
+
997
+ add_model_event = add_model_button.click(fn=dropdown_model_list,
998
+ inputs=[model_options_state, new_model],
999
+ outputs=[model_choice, new_model, model_options_state])
1000
+
1001
+ def dropdown_lora_list(list0, x):
1002
+ new_state = [list0[0] + [x]]
1003
+ new_options = [*new_state[0]]
1004
+ return gr.Dropdown.update(value=x, choices=new_options), '', new_state
1005
+
1006
+ add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
1007
+ inputs=[lora_options_state, new_lora],
1008
+ outputs=[lora_choice, new_lora, lora_options_state])
1009
+
1010
+ go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
1011
+ .then(lambda: gr.update(visible=True), None, normal_block) \
1012
+ .then(**load_model_args).then(**prompt_update_args)
1013
+
1014
+ # callback for logging flagged input/output
1015
+ callback.setup(inputs_list + [text_output], "flagged_data_points")
1016
+ flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
1017
+ api_name='flag')
1018
+ if kwargs['chat']:
1019
+
1020
+ # don't pass text_output, don't want to clear output, just stop it
1021
+ # FIXME: have to click once to stop output and second time to stop GPUs going
1022
+ stop_btn.click(lambda: None, None, None, cancels=[submit_event, submit_event2, submit_event3],
1023
+ queue=False, api_name='stop').then(clear_torch_cache)
1024
+
1025
+ demo.queue(concurrency_count=1)
1026
+ favicon_path = "h2o-logo.svg"
1027
+ demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1028
+ favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
1029
+ print("Started GUI", flush=True)
1030
+ demo.block_thread()
1031
+
1032
+
1033
+ input_args_list = ['model_state']
1034
+ inputs_kwargs_list = ['debug', 'chat', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
1035
+
1036
+
1037
+ def get_inputs_list(inputs_dict, model_lower):
1038
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
1039
+ inputs_list = []
1040
+ for k in inputs_list_names:
1041
+ if k == 'kwargs':
1042
+ continue
1043
+ if k in input_args_list + inputs_kwargs_list:
1044
+ # these are added via partial, not taken as input
1045
+ continue
1046
+ if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
1047
+ continue
1048
+ inputs_list.append(inputs_dict[k])
1049
+ return inputs_list
1050
+
1051
+
1052
+ # index of prompt_type in evaluate function, after model_state
1053
+ prompt_type_arg_id = 4
1054
+
1055
+ eval_func_param_names = ['instruction',
1056
+ 'iinput',
1057
+ 'context',
1058
+ 'stream_output',
1059
+ 'prompt_type',
1060
+ 'temperature',
1061
+ 'top_p',
1062
+ 'top_k',
1063
+ 'num_beams',
1064
+ 'max_new_tokens',
1065
+ 'min_new_tokens',
1066
+ 'early_stopping',
1067
+ 'max_time',
1068
+ 'repetition_penalty',
1069
+ 'num_return_sequences',
1070
+ 'do_sample',
1071
+ ]
1072
+
1073
+
1074
+ def evaluate(
1075
+ model_state,
1076
+ # START NOTE: Examples must have same order of parameters
1077
+ instruction,
1078
+ iinput,
1079
+ context,
1080
+ stream_output,
1081
+ prompt_type,
1082
+ temperature,
1083
+ top_p,
1084
+ top_k,
1085
+ num_beams,
1086
+ max_new_tokens,
1087
+ min_new_tokens,
1088
+ early_stopping,
1089
+ max_time,
1090
+ repetition_penalty,
1091
+ num_return_sequences,
1092
+ do_sample,
1093
+ # END NOTE: Examples must have same order of parameters
1094
+ src_lang=None,
1095
+ tgt_lang=None,
1096
+ debug=False,
1097
+ chat=False,
1098
+ hard_stop_list=None,
1099
+ sanitize_bot_response=True,
1100
+ model_state0=None,
1101
+ **kwargs,
1102
+ ):
1103
+ if debug:
1104
+ locals_dict = locals().copy()
1105
+ locals_dict.pop('model_state', None)
1106
+ print(locals_dict)
1107
+
1108
+ no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
1109
+
1110
+ if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
1111
+ # try to free-up original model (i.e. list was passed as reference)
1112
+ if model_state0 is not None and model_state0[0] is not None:
1113
+ model_state0[0].cpu()
1114
+ model_state0[0] = None
1115
+ # try to free-up original tokenizer (i.e. list was passed as reference)
1116
+ if model_state0 is not None and model_state0[1] is not None:
1117
+ model_state0[1] = None
1118
+ clear_torch_cache()
1119
+ model, tokenizer, device, base_model = model_state
1120
+ elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
1121
+ assert isinstance(model_state[0], str)
1122
+ model, tokenizer, device, base_model = model_state0
1123
+ else:
1124
+ raise AssertionError(no_model_msg)
1125
+
1126
+ assert base_model.strip(), no_model_msg
1127
+ assert model, "Model is missing"
1128
+ assert tokenizer, "Tokenizer is missing"
1129
+
1130
+ data_point = dict(context=context, instruction=instruction, input=iinput)
1131
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
1132
+ prompt = prompter.generate_prompt(data_point)
1133
+
1134
+ if hard_stop_list is None:
1135
+ # acts like undo on user entry and bot response
1136
+ hard_stop_list = []
1137
+
1138
+ if isinstance(tokenizer, str):
1139
+ # pipeline
1140
+ if tokenizer == "summarization":
1141
+ key = 'summary_text'
1142
+ else:
1143
+ raise RuntimeError("No such task type %s" % tokenizer)
1144
+ # NOTE: uses max_length only
1145
+ yield model(prompt, max_length=max_new_tokens)[0][key]
+ return
1146
+
1147
+ if 'mbart-' in base_model.lower():
1148
+ assert src_lang is not None
1149
+ tokenizer.src_lang = languages_covered()[src_lang]
1150
+
1151
+ if chat:
1152
+ # override, ignore user change
1153
+ num_return_sequences = 1
1154
+ if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
1155
+ if prompt_type == 'human_bot':
1156
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
1157
+ # stopping only starts once output is beyond prompt
1158
+ # one human marker is enough to trigger, but two bot markers are needed, because the very first text back will be the bot marker we added
1159
+ stop_words = [human, bot]
1160
+ encounters = [1, 2]
1161
+ elif prompt_type == 'instruct_vicuna':
1162
+ # even below is not enough, generic strings and many ways to encode
1163
+ stop_words = [
1164
+ '### Human:',
1165
+ """
1166
+ ### Human:""",
1167
+ """
1168
+ ### Human:
1169
+ """,
1170
+ '### Assistant:',
1171
+ """
1172
+ ### Assistant:""",
1173
+ """
1174
+ ### Assistant:
1175
+ """,
1176
+ ]
1177
+ encounters = [1, 2]
1178
+ else:
1179
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
1180
+ stop_words = ['### End']
1181
+ encounters = [1]
1182
+ stop_words_ids = [
1183
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
1184
+ # handle single token case
1185
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
1186
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
1187
+ # avoid padding in front of tokens
1188
+ if tokenizer.pad_token:
1189
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
1190
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
1191
+ else:
1192
+ stopping_criteria = StoppingCriteriaList()
1193
+
1194
+ # help to avoid errors like:
1195
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1196
+ # RuntimeError: expected scalar type Half but found Float
1197
+ # subtracting 256 leaves headroom and helps avoid these
1198
+ max_length_tokenize = 768 - 256 if os.environ.get("HUGGINGFACE_SPACES") else 2048 - 256
1199
+ cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1200
+ output_smallest = 30 * 4
1201
+ prompt = prompt[-cutoff_len - output_smallest:]
1202
+ inputs = tokenizer(prompt,
1203
+ return_tensors="pt",
1204
+ truncation=True,
1205
+ max_length=max_length_tokenize)
1206
+ if debug and len(inputs["input_ids"]) > 0:
1207
+ print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1208
+ input_ids = inputs["input_ids"].to(device)
1209
+ generation_config = GenerationConfig(
1210
+ temperature=float(temperature),
1211
+ top_p=float(top_p),
1212
+ top_k=top_k,
1213
+ num_beams=num_beams,
1214
+ do_sample=do_sample,
1215
+ repetition_penalty=float(repetition_penalty),
1216
+ num_return_sequences=num_return_sequences,
1217
+ renormalize_logits=True,
1218
+ remove_invalid_values=True,
1219
+ **kwargs,
1220
+ )
1221
+
1222
+ gen_kwargs = dict(input_ids=input_ids,
1223
+ generation_config=generation_config,
1224
+ return_dict_in_generate=True,
1225
+ output_scores=True,
1226
+ max_new_tokens=max_new_tokens, # new tokens only (excludes prompt)
1227
+ min_new_tokens=min_new_tokens, # new tokens only (excludes prompt)
1228
+ early_stopping=early_stopping, # False, True, "never"
1229
+ max_time=max_time,
1230
+ stopping_criteria=stopping_criteria,
1231
+ )
1232
+ if 'gpt2' in base_model.lower():
1233
+ gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
1234
+ elif 'mbart-' in base_model.lower():
1235
+ assert tgt_lang is not None
1236
+ tgt_lang = languages_covered()[tgt_lang]
1237
+ gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1238
+ else:
1239
+ gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
1240
+
1241
+ decoder = functools.partial(tokenizer.decode,
1242
+ skip_special_tokens=True,
1243
+ clean_up_tokenization_spaces=True,
1244
+ )
1245
+ decoder_raw = functools.partial(tokenizer.decode,
1246
+ skip_special_tokens=False,
1247
+ clean_up_tokenization_spaces=True,
1248
+ )
1249
+
1250
+ with torch.no_grad():
1251
+ # decoded tokenized prompt can deviate from prompt due to special characters
1252
+ inputs_decoded = decoder(input_ids[0])
1253
+ inputs_decoded_raw = decoder_raw(input_ids[0])
1254
+ if inputs_decoded == prompt:
1255
+ # normal
1256
+ pass
1257
+ elif inputs_decoded.lstrip() == prompt.lstrip():
1258
+ # sometimes extra space in front, make prompt same for prompt removal
1259
+ prompt = inputs_decoded
1260
+ elif inputs_decoded_raw == prompt:
1261
+ # some models specify special tokens that are part of normal prompt, so can't skip them
1262
+ inputs_decoded_raw = inputs_decoded
1263
+ decoder = decoder_raw
1264
+ else:
1265
+ print("WARNING: Special characters in prompt", flush=True)
1266
+ if stream_output:
1267
+ def generate(callback=None, **kwargs):
1268
+ # re-order stopping criteria so Stream runs first and emits all chunks before stopping for other reasons
1269
+ stopping_criteria0 = kwargs.get('stopping_criteria', StoppingCriteriaList()).copy()
1270
+ kwargs['stopping_criteria'] = StoppingCriteriaList()
1271
+ kwargs['stopping_criteria'].append(Stream(func=callback))
1272
+ for stopping_criteria1 in stopping_criteria0:
1273
+ kwargs['stopping_criteria'].append(stopping_criteria1)
1274
+
1275
+ try:
1276
+ model.generate(**kwargs)
1277
+ except torch.cuda.OutOfMemoryError as e:
1278
+ print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), flush=True)
1279
+ if kwargs['input_ids'] is not None:
1280
+ kwargs['input_ids'].cpu()
1281
+ kwargs['input_ids'] = None
1282
+ traceback.print_exc()
1283
+ clear_torch_cache()
1284
+ return
1285
+
1286
+ for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
1287
+ decoded_output = decoder(output)
1288
+ if output[-1] in [tokenizer.eos_token_id]:
1289
+ if debug:
1290
+ print("HIT EOS", flush=True)
1291
+ break
1292
+ if any(ele in decoded_output for ele in hard_stop_list):
1293
+ raise StopIteration
1294
+ yield prompter.get_response(decoded_output, prompt=inputs_decoded,
1295
+ sanitize_bot_response=sanitize_bot_response)
1296
+ return
1297
+ else:
1298
+ outputs = model.generate(**gen_kwargs)
1299
+ outputs = [decoder(s) for s in outputs.sequences]
1300
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
1301
+ sanitize_bot_response=sanitize_bot_response)
1302
+
1303
+
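The stop-word handling above delegates to StoppingCriteriaSub from stopping.py, which is not shown in this diff. As a minimal sketch only, here is one way such a criterion can be written against the same stops/encounters interface; the actual implementation may differ:

import torch
from transformers import StoppingCriteria


class StopOnTokens(StoppingCriteria):
    """Stop once any stop sequence has appeared at least `encounters[i]` times."""

    def __init__(self, stops, encounters):
        super().__init__()
        self.stops = stops            # list of 1-D LongTensors of token ids
        self.encounters = encounters  # occurrence threshold per stop sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        seq = input_ids[0]
        for stop, need in zip(self.stops, self.encounters):
            stop = stop.to(seq.device)
            n = stop.shape[0]
            # count occurrences of the stop sequence anywhere in the generated ids
            hits = sum(torch.equal(seq[i:i + n], stop) for i in range(len(seq) - n + 1))
            if hits >= need:
                return True
        return False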
1304
+ def get_generate_params(model_lower, chat,
1305
+ stream_output, show_examples,
1306
+ prompt_type, temperature, top_p, top_k, num_beams,
1307
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
1308
+ repetition_penalty, num_return_sequences,
1309
+ do_sample):
1310
+ use_defaults = False
1311
+ use_default_examples = True
1312
+ examples = []
1313
+ task_info = f"{prompt_type}"
1314
+ if model_lower:
1315
+ print(f"Using Model {model_lower}", flush=True)
1316
+ else:
1317
+ print("No model defined yet", flush=True)
1318
+
1319
+ min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
1320
+ early_stopping = early_stopping if early_stopping is not None else False
1321
+ max_time_defaults = 60 * 3
1322
+ max_time = max_time if max_time is not None else max_time_defaults
1323
+
1324
+ if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1325
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
1326
+
1327
+ if show_examples is None:
1328
+ if chat:
1329
+ show_examples = False
1330
+ else:
1331
+ show_examples = True
1332
+
1333
+ summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
1334
+ Philipp: Sure you can use the new Hugging Face Deep Learning Container.
1335
+ Jeff: ok.
1336
+ Jeff: and how can I get started?
1337
+ Jeff: where can I find documentation?
1338
+ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
1339
+
1340
+ if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
1341
+ placeholder_instruction = summarize_example1
1342
+ placeholder_input = ""
1343
+ use_defaults = True
1344
+ use_default_examples = False
1345
+ examples += [
1346
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1347
+ 1.0, 1,
1348
+ False]]
1349
+ task_info = "Summarization"
1350
+ elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
1351
+ placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
1352
+ placeholder_input = ""
1353
+ use_defaults = True
1354
+ use_default_examples = True
1355
+ task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
1356
+ elif 'mbart-' in model_lower:
1357
+ placeholder_instruction = "The girl has long hair."
1358
+ placeholder_input = ""
1359
+ use_defaults = True
1360
+ use_default_examples = False
1361
+ examples += [
1362
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1363
+ 1.0, 1,
1364
+ False]]
1365
+ elif 'gpt2' in model_lower:
1366
+ placeholder_instruction = "The sky is"
1367
+ placeholder_input = ""
1368
+ prompt_type = prompt_type or 'plain'
1369
+ use_default_examples = True # some will be odd "continuations" but can be ok
1370
+ examples += [
1371
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
1372
+ 1.0, 1,
1373
+ False]]
1374
+ task_info = "Auto-complete phrase, code, etc."
1375
+ use_defaults = True
1376
+ else:
1377
+ if chat:
1378
+ placeholder_instruction = "Enter a question or imperative."
1379
+ else:
1380
+ placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1381
+ placeholder_input = ""
1382
+ if model_lower:
1383
+ prompt_type = prompt_type or 'human_bot'
1384
+ else:
1385
+ prompt_type = ''
1386
+ examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
1387
+ stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1, False]]
1388
+ task_info = "No task"
1389
+ if prompt_type == 'instruct':
1390
+ task_info = "Answer question or follow imperative as instruction with optionally input."
1391
+ elif prompt_type == 'plain':
1392
+ task_info = "Auto-complete phrase, code, etc."
1393
+ elif prompt_type == 'human_bot':
1394
+ if chat:
1395
+ task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
1396
+ else:
1397
+ task_info = "Ask question/imperative (input concatenated with instruction)"
1398
+
1399
+ # revert to plain if still nothing
1400
+ prompt_type = prompt_type or 'plain'
1401
+ if use_defaults:
1402
+ temperature = 1.0 if temperature is None else temperature
1403
+ top_p = 1.0 if top_p is None else top_p
1404
+ top_k = 40 if top_k is None else top_k
1405
+ num_beams = num_beams or 1
1406
+ max_new_tokens = max_new_tokens or 128
1407
+ repetition_penalty = repetition_penalty or 1.07
1408
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1409
+ do_sample = False if do_sample is None else do_sample
1410
+ else:
1411
+ temperature = 0.1 if temperature is None else temperature
1412
+ top_p = 0.75 if top_p is None else top_p
1413
+ top_k = 40 if top_k is None else top_k
1414
+ if chat:
1415
+ num_beams = num_beams or 1
1416
+ else:
1417
+ num_beams = num_beams or 4
1418
+ max_new_tokens = max_new_tokens or 256
1419
+ repetition_penalty = repetition_penalty or 1.07
1420
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1421
+ do_sample = False if do_sample is None else do_sample
1422
+ params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1423
+ early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
1424
+
1425
+ if use_default_examples:
1426
+ examples += [
1427
+ ["Translate English to French", "Good morning"] + params_list,
1428
+ ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
1429
+ ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
1430
+ [
1431
+ "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
1432
+ ''] + params_list,
1433
+ ['Translate to German: My name is Arthur', ''] + params_list,
1434
+ ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
1435
+ ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
1436
+ ''] + params_list,
1437
+ ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
1438
+ ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
1439
+ ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
1440
+ [
1441
+ "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
1442
+ ''] + params_list,
1443
+ ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
1444
+ [
1445
+ 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
1446
+ ''] + params_list,
1447
+ ["""def area_of_rectangle(a: float, b: float):
1448
+ \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
1449
+ ["""# a function in native python:
1450
+ def mean(a):
1451
+ return sum(a)/len(a)
1452
+
1453
+ # the same function using numpy:
1454
+ import numpy as np
1455
+ def mean(a):""", ''] + params_list,
1456
+ ["""X = np.random.randn(100, 100)
1457
+ y = np.random.randint(0, 1, 100)
1458
+
1459
+ # fit random forest classifier with 20 estimators""", ''] + params_list,
1460
+ ]
1461
+
1462
+ src_lang = "English"
1463
+ tgt_lang = "Russian"
1464
+
1465
+ return placeholder_instruction, placeholder_input, \
1466
+ stream_output, show_examples, \
1467
+ prompt_type, temperature, top_p, top_k, num_beams, \
1468
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
1469
+ repetition_penalty, num_return_sequences, \
1470
+ do_sample, \
1471
+ src_lang, tgt_lang, \
1472
+ examples, \
1473
+ task_info
1474
+
1475
+
1476
+ def languages_covered():
1477
+ # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
1478
+ covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
1479
+ covered = covered.split(', ')
1480
+ covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
1481
+ return covered
1482
+
1483
+
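For reference, the mapping above is keyed by plain language name and yields the mBART language code, e.g.:

lang_map = languages_covered()
print(lang_map['English'])  # 'en_XX'
print(lang_map['Russian'])  # 'ru_RU'
# evaluate() uses these codes to set tokenizer.src_lang and the forced BOS token for mBART.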
1484
+ def test_test_prompt(prompt_type='instruct', data_point=0):
1485
+ example_data_point = example_data_points[data_point]
1486
+ example_data_point.pop('output', None)
1487
+ return generate_prompt(example_data_point, prompt_type, False, False)
1488
+
1489
+
1490
+ if __name__ == "__main__":
1491
+ print("""
1492
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1493
+ python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1494
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
1495
+
1496
+ # generate without lora weights, no prompt
1497
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
1498
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
1499
+
1500
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
1501
+ # OpenChatKit settings:
1502
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
1503
+
1504
+ python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
1505
+ python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
1506
+ python generate.py --base_model='philschmid/bart-large-cnn-samsum'
1507
+ python generate.py --base_model='philschmid/flan-t5-base-samsum'
1508
+ python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
1509
+
1510
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
1511
+
1512
+ """, flush=True)
1513
+ fire.Fire(main)
client_test.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ Client test. Simplest case is chat=False and stream_output=False
3
+
4
+ Run server with same choices:
5
+
6
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b --chat=False --stream_output=False
7
+
8
+ NOTE: For private models, add --use_auth_token=True
9
+
10
+ NOTE: --infer_devices=True (default) must be used for multi-GPU in case you see failures with cuda:x vs. cuda:y device mismatches.
11
+ Currently, this will force the model onto a single GPU.
12
+
13
+ Then run this client as:
14
+
15
+ python client_test.py
16
+ """
17
+
18
+ debug = False
19
+
20
+ import time
21
+ import os
22
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
23
+ from gradio_client import Client
24
+
25
+ client = Client("http://localhost:7860")
26
+ if debug:
27
+ print(client.view_api(all_endpoints=True))
28
+
29
+ instruction = "Who are you?"
30
+ iinput = ''
31
+ context = ''
32
+ # streaming output is supported, loops over and outputs each generation in streaming mode
33
+ # but leave stream_output=False for simple input/output mode
34
+ stream_output = False
35
+ prompt_type = 'human_bot'
36
+ temperature = 0.1
37
+ top_p = 0.75
38
+ top_k = 40
39
+ num_beams = 1
40
+ max_new_tokens = 500
41
+ min_new_tokens = 0
42
+ early_stopping = False
43
+ max_time = 180
44
+ repetition_penalty = 1.0
45
+ num_return_sequences = 1
46
+ do_sample = True
47
+
48
+ # CHOOSE: must match server
49
+ # NOTE chat mode works through files on gradio
50
+ # and client currently would have to work through those files
51
+ # in tmp, so not best for client. So default to False
52
+ chat = False
53
+
54
+
55
+ def test_client_basic():
56
+ args = [instruction,
57
+ iinput,
58
+ context,
59
+ stream_output,
60
+ prompt_type,
61
+ temperature,
62
+ top_p,
63
+ top_k,
64
+ num_beams,
65
+ max_new_tokens,
66
+ min_new_tokens,
67
+ early_stopping,
68
+ max_time,
69
+ repetition_penalty,
70
+ num_return_sequences,
71
+ do_sample]
72
+
73
+ if not chat:
74
+ # requires generate.py to run with --chat=False
75
+ api_name = '/submit'
76
+ res = client.predict(
77
+ *tuple(args),
78
+ api_name=api_name,
79
+ )
80
+ print(md_to_text(res))
81
+ else:
82
+ api_name = '/instruction'
83
+ import json
84
+ foofile = '/tmp/foo.json'
85
+ with open(foofile, 'wt') as f:
86
+ json.dump([['', None]], f)
87
+ args += [foofile]
88
+ if not stream_output:
89
+ for res in client.predict(
90
+ *tuple(args),
91
+ api_name=api_name,
92
+ ):
93
+ print(res)
94
+ res_file = client.predict(*tuple(args), api_name='/instruction_bot')
95
+ res = json.load(open(res_file, "rt"))[-1][-1]
96
+ print(md_to_text(res))
97
+ else:
98
+ print("streaming instruction_bot", flush=True)
99
+ job = client.submit(*tuple(args), api_name='/instruction_bot')
100
+ while not job.done():
101
+ outputs_list = job.communicator.job.outputs
102
+ if outputs_list:
103
+ res_file = job.communicator.job.outputs[-1]
104
+ res = json.load(open(res_file, "rt"))[-1][-1]
105
+ print(md_to_text(res))
106
+ time.sleep(0.1)
107
+ print(job.outputs())
108
+
109
+
110
+ import markdown # pip install markdown
111
+ from bs4 import BeautifulSoup # pip install beautifulsoup4
112
+
113
+
114
+ def md_to_text(md):
115
+ html = markdown.markdown(md)
116
+ soup = BeautifulSoup(html, features='html.parser')
117
+ return soup.get_text()
118
+
119
+
120
+ if __name__ == '__main__':
121
+ test_client_basic()
finetune.py ADDED
@@ -0,0 +1,930 @@
1
+ import os
2
+ import pathlib
3
+ import random
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from datetime import datetime
9
+ from typing import List, Union
10
+ import fire
11
+ import numpy as np
12
+ import torch
13
+ from datasets import load_dataset, concatenate_datasets
14
+ import transformers
15
+ import torch.distributed as dist
16
+
17
+ from peft import (
18
+ prepare_model_for_int8_training,
19
+ LoraConfig,
20
+ get_peft_model,
21
+ get_peft_model_state_dict,
22
+ set_peft_model_state_dict,
23
+ )
24
+
25
+ from peft import mapping
26
+ lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
27
+
28
+
29
+ def log(*args, **kwargs):
30
+ if int(os.environ.get("LOCAL_RANK", 0)) == 0:
31
+ print(*args, **kwargs)
32
+
33
+
34
+ try:
35
+ import neptune
36
+ from transformers.integrations import NeptuneCallback
37
+
38
+ neptune_run = neptune.init_run(
39
+ source_files=[],
40
+ )
41
+ log("Connected to Neptune.")
42
+ except ImportError:
43
+ neptune_run = None
44
+ log("Please pip install neptune for tracking.")
45
+ except neptune.exceptions.NeptuneMissingApiTokenException:
46
+ neptune_run = None
47
+ os.environ["NEPTUNE_MODE"] = 'debug'
48
+ log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
49
+
50
+ from enum import Enum
51
+
52
+
53
+ class PromptType(Enum):
54
+ plain = 0
55
+ instruct = 1
56
+ quality = 2
57
+ human_bot = 3
58
+ dai_faq = 4
59
+ summarize = 5
60
+ simple_instruct = 6
61
+ instruct_vicuna = 7
62
+ instruct_with_end = 8
63
+ human_bot_orig = 9
64
+
65
+
66
+ prompt_type_to_model_name = {
67
+ 'plain': [
68
+ 'EleutherAI/gpt-j-6B',
69
+ 'EleutherAI/pythia-6.9b',
70
+ 'EleutherAI/pythia-12b',
71
+ 'EleutherAI/pythia-12b-deduped',
72
+ 'EleutherAI/gpt-neox-20b',
73
+ 'decapoda-research/llama-7b-hf',
74
+ 'decapoda-research/llama-13b-hf',
75
+ 'decapoda-research/llama-30b-hf',
76
+ 'facebook/mbart-large-50-many-to-many-mmt',
77
+ 'philschmid/bart-large-cnn-samsum',
78
+ 'philschmid/flan-t5-base-samsum',
79
+ 'gpt2',
80
+ 'distilgpt2',
81
+ ],
82
+ 'instruct': [],
83
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
84
+ 'quality': [],
85
+ 'human_bot': [
86
+ 'h2oai/h2ogpt-oig-oasst1-256-12b',
87
+ 'h2oai/h2ogpt-oasst1-512-12b',
88
+ 'h2oai/h2ogpt-oasst1-256-20b',
89
+ 'h2oai/h2ogpt-oasst1-512-20b',
90
+ 'h2oai/h2ogpt-oig-oasst1-256-6.9b',
91
+ ],
92
+ 'dai_faq': [],
93
+ 'summarize': [],
94
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
95
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
96
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
97
+ }
98
+
99
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
100
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
101
+
102
+ human = '<human>:'
103
+ bot = "<bot>:"
104
+
105
+ prompt_types_strings = []
106
+ for p in PromptType:
107
+ prompt_types_strings.extend([p.name])
108
+
109
+
110
+ prompt_types = []
111
+ for p in PromptType:
112
+ prompt_types.extend([p.name, p.value, str(p.value)])
113
+
114
+
115
+ # supported by huggingface evaluate
116
+ supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
117
+
118
+
119
+ def train(
120
+ save_code: bool = False,
121
+ run_id: int = None,
122
+
123
+ base_model: str = 'EleutherAI/gpt-neox-20b',
124
+ # base_model: str = 'EleutherAI/pythia-12b-deduped',
125
+ # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
126
+ # base_model: str = 'decapoda-research/llama-7b-hf',
127
+ # base_model: str = 'decapoda-research/llama-13b-hf',
128
+ # base_model: str = 'decapoda-research/llama-30b-hf',
129
+ # base_model: str = 'EleutherAI/gpt-j-6B',
130
+
131
+ # only needed if base_model is self-exported HF state without tokenizer
132
+ tokenizer_base_model: str = None,
133
+ # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
134
+
135
+ data_path: str = None,
136
+ data_col_dict: dict = None,
137
+ # data_path: str = "./dai_docs.train.json",
138
+ prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
139
+
140
+ valid_path: str = None,
141
+ # valid_path: str = "./dai_docs.valid.json",
142
+
143
+ # data_mix_in_path: str = "laion/OIG", # way too big, medium quality
144
+ data_mix_in_path: str = "0-hero/OIG-small-chip2", # high quality, 50 MB, good enough for now
145
+ data_mix_in_factor: float = 0.0, # >1: more mix-in data, <1: more of data_path data
146
+ data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
147
+ data_mix_in_prompt_type: str = "instruct", # just instruction->output, same as instruct
148
+
149
+ output_dir: str = None,
150
+
151
+ # LoRA checkpoint continuation
152
+ lora_weights: str = "",
153
+
154
+ # batching training hyperparams
155
+ batch_size: int = 128,
156
+ micro_batch_size: int = 4,
157
+ gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
158
+ fp16=True,
159
+
160
+ # general training hyperparams
161
+ num_epochs: float = 1,
162
+ learning_rate: float = 3e-4,
163
+
164
+ # validation settings
165
+ val_set_size: int = None,
166
+ val_metrics: List[str] = [],
167
+ eval_steps: int = None, # to control eval steps via steps
168
+ eval_epochs: float = None, # to control eval steps via epochs
169
+
170
+ # lora hyperparams
171
+ lora_r: int = 8,
172
+ lora_alpha: int = 16,
173
+ lora_dropout: float = 0.05,
174
+ lora_target_modules: List[str] = None,
175
+ llama_type: bool = None,
176
+
177
+ # llm hyperparams
178
+ train_on_inputs: bool = True, # if False, masks out inputs in loss
179
+ group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
180
+ resume_from_checkpoint: str = None, # either training checkpoint or final adapter
181
+ cutoff_len: int = 1024, # Good default, especially when have high quality non-trivial data
182
+
183
+ # torch training params
184
+ ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
185
+ local_files_only: bool = False, # else will download new versions, normally unwanted
186
+ resume_download: bool = True,
187
+ use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
188
+ warmup_steps: int = 100,
189
+ logging_steps: int = 1,
190
+ save_steps: int = None, # must be round multiple of eval_steps
191
+ add_eos_token: bool = False,
192
+ ):
193
+ # allow set token directly
194
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
195
+
196
+ prompt_type = str(prompt_type) # migration from integers
197
+ assert prompt_type in prompt_types
198
+
199
+ world_size = int(os.getenv("WORLD_SIZE", 1))
200
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
201
+ rank = int(os.getenv("RANK", 0))
202
+ print(f"local_rank: {local_rank}")
203
+ print(f"global rank: {rank}")
204
+
205
+ gpus = max(world_size, torch.cuda.device_count())
206
+ run_id = run_id or 0
207
+ if not data_path:
208
+ raise ValueError("No data_path provided")
209
+ if not output_dir:
210
+ output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
211
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
212
+ raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
213
+ else:
214
+ if os.path.exists(output_dir) and not resume_from_checkpoint:
215
+ raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
216
+ device_map = "auto"
217
+
218
+ if save_code:
219
+ copy_code(run_id)
220
+ if tokenizer_base_model is None:
221
+ tokenizer_base_model = base_model
222
+ if llama_type is None:
223
+ llama_type = "llama" in base_model.lower()
224
+ assert (
225
+ base_model
226
+ ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
227
+ gradient_accumulation_steps = batch_size // micro_batch_size
228
+ assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
229
+
230
+ device_map = "auto"
231
+
232
+ locals_dict = locals()
233
+ locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
234
+ log(f"Training model with params:\n{locals_print}")
235
+ log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
236
+
237
+ max_memory = None
238
+ if gpus > 1:
239
+ if ddp:
240
+ log("Distributed: data parallel")
241
+ device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
242
+ gradient_accumulation_steps = gradient_accumulation_steps // world_size
243
+ else:
244
+ free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
245
+ max_memory = f"{free_in_GB - 2}GB"
246
+ max_memory = {i: max_memory for i in range(gpus)}
247
+ log("world_size: %d" % world_size)
248
+ log("num_gpus: %d" % gpus)
249
+ log("max mem: %s" % max_memory)
250
+
251
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
252
+
253
+ model = model_loader.from_pretrained(
254
+ base_model,
255
+ load_in_8bit=True,
256
+ device_map=device_map,
257
+ torch_dtype=torch.float16,
258
+ max_memory=max_memory,
259
+ local_files_only=local_files_only,
260
+ resume_download=resume_download,
261
+ use_auth_token=use_auth_token,
262
+ )
263
+ if gpus > 1:
264
+ if not ddp:
265
+ log("model parallel")
266
+ model.is_parallelizable = True
267
+ model.model_parallel = True
268
+
269
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
270
+ local_files_only=local_files_only,
271
+ resume_download=resume_download,
272
+ use_auth_token=use_auth_token)
273
+
274
+ tokenizer.pad_token_id = 0 # different from the eos token
275
+ # when generating, we will use the logits of right-most token to predict the next token
276
+ # so the padding should be on the left,
277
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
278
+ tokenizer.padding_side = "left" # Allow batched inference
279
+
280
+ def tokenize(prompt, add_eos_token=True):
281
+ # there's probably a way to do this with the tokenizer settings
282
+ # but again, gotta move fast
283
+ result = tokenizer(
284
+ prompt,
285
+ truncation=True,
286
+ max_length=cutoff_len,
287
+ padding=False,
288
+ return_tensors=None,
289
+ )
290
+ if (
291
+ result["input_ids"][-1] != tokenizer.eos_token_id
292
+ and len(result["input_ids"]) < cutoff_len
293
+ and add_eos_token
294
+ ):
295
+ result["input_ids"].append(tokenizer.eos_token_id)
296
+ result["attention_mask"].append(1)
297
+
298
+ result["labels"] = result["input_ids"].copy()
299
+
300
+ return result
301
+
302
+ def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
303
+ full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
304
+ tokenized_full_prompt = tokenize(full_prompt)
305
+ if not train_on_inputs:
306
+ user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
307
+ tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
308
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
309
+ if add_eos:
310
+ user_prompt_len -= 1
311
+
312
+ # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
313
+ tokenized_full_prompt["labels"] = [
314
+ -100
315
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
316
+ user_prompt_len:
317
+ ] # could be sped up, probably
318
+ return tokenized_full_prompt
319
+
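To make the label masking above concrete, here is a toy illustration with made-up token ids (not from any real tokenizer):

# prompt-only tokens (instruction/input); lengths are illustrative
user_ids = [11, 22, 33]
# full prompt + response + eos
full_ids = [11, 22, 33, 44, 55, 2]
labels = [-100] * len(user_ids) + full_ids[len(user_ids):]
# labels == [-100, -100, -100, 44, 55, 2]
# CrossEntropyLoss ignores the -100 positions, so only the response tokens contribute to the loss.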
320
+ if "gpt-neox" not in base_model or True:
321
+ model = prepare_model_for_int8_training(model)
322
+ else:
323
+ model = prepare_model_for_int8_training(
324
+ model,
325
+ output_embedding_layer_name="embed_out", # keep output logits in float32
326
+ layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
327
+ )
328
+ if lora_weights:
329
+ from peft import PeftModel
330
+ model = PeftModel.from_pretrained(
331
+ model,
332
+ lora_weights,
333
+ torch_dtype=torch.float16,
334
+ device_map=device_map,
335
+ local_files_only=local_files_only,
336
+ resume_download=resume_download,
337
+ use_auth_token=use_auth_token,
338
+ )
339
+ else:
340
+ if lora_target_modules is None:
341
+ base_model_lower = base_model.lower()
342
+ if base_model_lower in lora_mappings:
343
+ lora_target_modules_cand = [lora_mappings[base_model_lower]]
344
+ else:
345
+ lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
346
+ else:
347
+ lora_target_modules_cand = [lora_target_modules]
348
+
349
+ for lora_target_modules in lora_target_modules_cand:
350
+ try:
351
+ config = LoraConfig(
352
+ r=lora_r,
353
+ lora_alpha=lora_alpha,
354
+ target_modules=lora_target_modules,
355
+ lora_dropout=lora_dropout,
356
+ bias="none",
357
+ task_type="CAUSAL_LM",
358
+ )
359
+ model = get_peft_model(model, config)
360
+ break
361
+ except ValueError as e:
362
+ if "Target modules" in str(e) and "not found" in str(e):
363
+ continue
364
+ else:
365
+ raise
366
+ from peft import PeftModel
367
+ assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
368
+ if resume_from_checkpoint:
369
+ # Check the available weights and load them
370
+ checkpoint_name = os.path.join(
371
+ resume_from_checkpoint, "pytorch_model.bin"
372
+ ) # Full checkpoint
373
+ if not os.path.exists(checkpoint_name):
374
+ checkpoint_name = os.path.join(
375
+ resume_from_checkpoint, "adapter_model.bin"
376
+ ) # only LoRA model - LoRA config above has to fit
377
+ resume_from_checkpoint = False # So the trainer won't try loading its state
378
+ # The two files above have a different name depending on how they were saved, but are actually the same.
379
+ if os.path.exists(checkpoint_name):
380
+ log(f"Restarting from {checkpoint_name}")
381
+ adapters_weights = torch.load(checkpoint_name)
382
+ model = set_peft_model_state_dict(model, adapters_weights)
383
+ else:
384
+ log(f"Checkpoint {checkpoint_name} not found")
385
+
386
+ print(model)
387
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
388
+
389
+ metrics = {}
390
+ for name in supported_metrics:
391
+ if name in val_metrics:
392
+ import evaluate # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
393
+ metrics[name] = evaluate.load(name)
394
+ log("Using Validation Metrics: %s" % str(list(metrics.keys())))
395
+ log("Supported Metrics: %s" % supported_metrics)
396
+
397
+ if val_set_size is None:
398
+ if len(metrics) == 0:
399
+ val_set_size = 1000
400
+ else:
401
+ val_set_size = 100
402
+ log("Auto set val_set_size %s" % val_set_size)
403
+ elif val_set_size < 1.0 and val_set_size != 0:
404
+ raise RuntimeError("Fractional validation size not supported.")
405
+
406
+ if valid_path:
407
+ data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
408
+ else:
409
+ if "json" in data_path:
410
+ data = load_dataset("json", data_files={"train": data_path})
411
+ else:
412
+ data = load_dataset(data_path)
413
+ data = data.rename_columns(data_col_dict or {})
414
+
415
+ valid_data = None
416
+ train_data_mix_in = None
417
+ valid_data_mix_in = None
418
+
419
+ if data_mix_in_path and data_mix_in_factor > 0:
420
+ # get mix-in training/validation data - to keep model "sane"
421
+ num_rows = data["train"].num_rows
422
+ log("Loading mix-in dataset: %s" % data_mix_in_path)
423
+ if "json" in data_mix_in_path:
424
+ data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
425
+ else:
426
+ data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
427
+ data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
428
+
429
+ # only get as much as we need to balance
430
+ valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
431
+ train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
432
+ mixin_small = data_mix_in.train_test_split(
433
+ test_size=train_size + valid_size,
434
+ shuffle=True, seed=np.random.randint(10000),
435
+ )["test"]
436
+ if valid_size:
437
+ mixin_train_test = mixin_small.train_test_split(
438
+ test_size=valid_size, shuffle=False,
439
+ )
440
+ train_data_mix_in = mixin_train_test["train"]
441
+ valid_data_mix_in = mixin_train_test["test"]
442
+ else:
443
+ train_data_mix_in = mixin_small
444
+
445
+ if "prompt_type" not in train_data_mix_in.column_names:
446
+ train_data_mix_in = train_data_mix_in.add_column(
447
+ "prompt_type",
448
+ [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
449
+ )
450
+ log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
451
+ if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
452
+ valid_data_mix_in = valid_data_mix_in.add_column(
453
+ "prompt_type",
454
+ [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
455
+ )
456
+ log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
457
+ log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))
458
+
459
+ # get our own training/validation data - for fine-tuning
460
+ if val_set_size > 0 and not valid_path and not data_mix_in_path:
461
+ # create valid split from train
462
+ train_val = data["train"].train_test_split(
463
+ test_size=val_set_size, shuffle=True, seed=42
464
+ )
465
+ train_data = train_val["train"]
466
+ valid_data = train_val["test"]
467
+ else:
468
+ train_data = data["train"]
469
+ if valid_path:
470
+ # use given valid split, has priority over data_mix_in_path
471
+ valid_data = data["valid"]
472
+ if "prompt_type" not in train_data.column_names:
473
+ train_data = train_data.add_column(
474
+ "prompt_type",
475
+ [prompt_type] * train_data.num_rows,
476
+ )
477
+ log("Added prompt type %s to training data" % prompt_type)
478
+ if valid_data and "prompt_type" not in valid_data.column_names:
479
+ valid_data = valid_data.add_column(
480
+ "prompt_type",
481
+ [prompt_type] * valid_data.num_rows,
482
+ )
483
+ log("Added prompt type %s to validation data" % prompt_type)
484
+
485
+ assert train_data is not None
486
+
487
+ # shuffle and tokenize data
488
+ if train_data_mix_in:
489
+ train_data = concatenate_datasets([train_data, train_data_mix_in])
490
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
491
+ train_set_size = len(train_data)
492
+
493
+ if valid_data and valid_data_mix_in:
494
+ valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
495
+ elif valid_data_mix_in:
496
+ valid_data = valid_data_mix_in
497
+
498
+ if valid_data:
499
+ valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
500
+ val_set_size = len(valid_data)
501
+ else:
502
+ val_set_size = 0
503
+ log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
504
+ sample_row_dict = train_data[:1]
505
+ del sample_row_dict['input_ids']
506
+ del sample_row_dict['attention_mask']
507
+ del sample_row_dict['labels']
508
+ log("Sample input: %s" % sample_row_dict)
509
+
510
+ if neptune_run:
511
+ neptune_callback = NeptuneCallback(run=neptune_run)
512
+ callbacks = [neptune_callback]
513
+ else:
514
+ from transformers.integrations import TensorBoardCallback, is_tensorboard_available
515
+ if is_tensorboard_available:
516
+ # tensorboard --logdir=runs/
517
+ from torch.utils.tensorboard import SummaryWriter
518
+ tb_writer = SummaryWriter()
519
+ callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
520
+ else:
521
+ callbacks = []
522
+
523
+ expected_steps = (train_set_size * num_epochs) // batch_size
524
+ if eval_steps is None and eval_epochs is None:
525
+ # 20 evaluations for a run
526
+ eval_steps = max(1, int(expected_steps / 20))
527
+ log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
528
+ elif eval_steps is None and eval_epochs is not None:
529
+ eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
530
+ log("Auto converted eval_epochs=%s to eval_steps %s"
531
+ " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
532
+ if save_steps is None:
533
+ save_steps = eval_steps
534
+ log("Auto step save_steps to %s" % save_steps)
535
+ elif save_steps > eval_steps:
536
+ # save steps must be round multiple of eval_steps
537
+ save_steps0 = save_steps
538
+ save_steps = max(1, (save_steps//eval_steps)) * eval_steps
539
+ if save_steps0 != save_steps:
540
+ log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))
541
+
542
+ def compute_metrics(eval_preds):
543
+ # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
544
+ inputs = eval_preds.inputs
545
+ label_ids = eval_preds.label_ids
546
+ predictions = eval_preds.predictions
547
+
548
+ #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
549
+ #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
550
+ #decoded_inputs = [pred.strip() for pred in decoded_inputs]
551
+
552
+ label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
553
+ # tokenizer behavior like generate time
554
+ decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
555
+ clean_up_tokenization_spaces=True)
556
+ decoded_labels = [pred.strip() for pred in decoded_labels]
557
+
558
+ predictions = np.argmax(predictions, -1)
559
+ predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
560
+ # tokenizer behavior like generate time
561
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
562
+ clean_up_tokenization_spaces=True)
563
+ decoded_predictions = [pred.strip() for pred in decoded_predictions]
564
+
565
+ result = {}
566
+ for metric in metrics.values():
567
+ result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
568
+ # get rid of lists, for precision etc., for now
569
+ numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
570
+ result.update(numeric_results)
571
+ return result
572
+
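For orientation, a single metric call looks roughly like this (exact keys and value types vary with the evaluate version, which is why compute_metrics filters for int/float entries):

import evaluate

rouge = evaluate.load('rouge')
scores = rouge.compute(predictions=['the cat sat on the mat'],
                       references=['the cat sat on the mat'])
# -> dict of scores, e.g. {'rouge1': 1.0, 'rouge2': 1.0, ...} for identical strings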
573
+ # the callback that computes metrics of interest
574
+ if val_metrics:
575
+ trainer_kwargs = dict(compute_metrics=compute_metrics)
576
+ else:
577
+ trainer_kwargs = dict()
578
+
579
+ trainer = transformers.Trainer(
580
+ model=model,
581
+ tokenizer=tokenizer,
582
+ train_dataset=train_data,
583
+ eval_dataset=valid_data,
584
+ # NOTE: CausalLM does not use the Seq2Seq-specific arguments, but Seq2SeqTrainingArguments is not incompatible with it
585
+ args=transformers.Seq2SeqTrainingArguments(
586
+ per_device_train_batch_size=micro_batch_size,
587
+ per_device_eval_batch_size=1,
588
+ eval_accumulation_steps=10,
589
+ # predict_with_generate=True, # SEQ2SEQ only
590
+ include_inputs_for_metrics=True,
591
+ gradient_accumulation_steps=gradient_accumulation_steps,
592
+ warmup_steps=warmup_steps,
593
+ num_train_epochs=num_epochs,
594
+ learning_rate=learning_rate,
595
+ gradient_checkpointing=gradient_checkpointing,
596
+ fp16=fp16,
597
+ # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
598
+ optim="adamw_torch", # consider "adafactor" to save memory
599
+ logging_steps=logging_steps,
600
+ logging_strategy="steps",
601
+ evaluation_strategy="steps" if val_set_size > 0 else "no",
602
+ save_strategy="steps",
603
+ eval_steps=eval_steps if val_set_size > 0 else None,
604
+ save_steps=save_steps,
605
+ output_dir=output_dir,
606
+ save_total_limit=3,
607
+ load_best_model_at_end=True if val_set_size > 0 else False,
608
+ ddp_find_unused_parameters=False if ddp else None,
609
+ group_by_length=group_by_length,
610
+ #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
611
+ #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
612
+ report_to='tensorboard' if not neptune_run else 'neptune',
613
+ ),
614
+ data_collator=transformers.DataCollatorForSeq2Seq(
615
+ tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
616
+ ),
617
+ callbacks=callbacks,
618
+ **trainer_kwargs,
619
+ )
620
+ model.config.use_cache = False
621
+
622
+ old_state_dict = model.state_dict
623
+ model.state_dict = (
624
+ lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
625
+ ).__get__(model, type(model))
626
+
627
+ if torch.__version__ >= "2" and sys.platform != "win32":
628
+ model = torch.compile(model)
629
+ # WIP (not generally replacing layers until pytorch 2.1)
630
+ torch.backends.cuda.enable_flash_sdp(True)
631
+
632
+ if gpus > 1 and not ddp:
633
+ assert trainer.is_model_parallel
634
+ else:
635
+ assert not trainer.is_model_parallel
636
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
637
+
638
+ model.save_pretrained(output_dir)
639
+
640
+ log("\n If there's a warning about missing keys above, please disregard :)")
641
+
642
+
643
+ def get_loaders(llama_type, model_name, reward_type):
644
+ # NOTE: Some models need specific new prompt_type
645
+ # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT"
646
+ if llama_type:
647
+ from transformers import LlamaForCausalLM, LlamaTokenizer
648
+ model_loader = LlamaForCausalLM
649
+ tokenizer_loader = LlamaTokenizer
650
+ elif 'gpt2' in model_name.lower():
651
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
652
+ return GPT2LMHeadModel, GPT2Tokenizer
653
+ elif 'mbart-' in model_name.lower():
654
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
655
+ return MBartForConditionalGeneration, MBart50TokenizerFast
656
+ elif 't5' == model_name.lower() or \
657
+ 't5-' in model_name.lower() or \
658
+ 'flan-' in model_name.lower():
659
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
660
+ return T5ForConditionalGeneration, AutoTokenizer
661
+ elif 'bigbird' in model_name:
662
+ from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
663
+ return BigBirdPegasusForConditionalGeneration, AutoTokenizer
664
+ elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
665
+ from transformers import pipeline
666
+ return pipeline, "summarization"
667
+ elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
668
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
669
+ return AutoModelForSequenceClassification, AutoTokenizer
670
+ else:
671
+ from transformers import AutoTokenizer, AutoModelForCausalLM
672
+ model_loader = AutoModelForCausalLM
673
+ tokenizer_loader = AutoTokenizer
674
+ return model_loader, tokenizer_loader
675
+
676
+
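A minimal usage sketch for the common case where a (model class, tokenizer class) pair is returned (the samsum entries return a pipeline/task pair instead); gpt2 is used here only because it is small:

model_loader, tokenizer_loader = get_loaders(llama_type=False, model_name='gpt2', reward_type=False)
tokenizer = tokenizer_loader.from_pretrained('gpt2')
model = model_loader.from_pretrained('gpt2')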
677
+ def get_githash():
678
+ try:
679
+ githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
680
+ except:
681
+ githash = ''
682
+ return githash
683
+
684
+
685
+ def copy_code(run_id):
686
+ """
687
+ copy code to track changes
688
+ :param run_id:
689
+ :return:
690
+ """
691
+ rnd_num = str(random.randint(0, 2 ** 31))
692
+ run_id = 'run_' + str(run_id)
693
+ os.makedirs(run_id, exist_ok=True)
694
+ me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
695
+ me_file = os.path.basename(__file__)
696
+ new_me = os.path.join(run_id, me_file + '_' + get_githash())
697
+ if os.path.isfile(new_me):
698
+ new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
699
+ shutil.copy(me_full, new_me)
700
+ else:
701
+ shutil.copy(me_full, new_me)
702
+
703
+
704
+ def get_prompt(prompt_type, chat, context, reduced):
705
+ if prompt_type in [-1, "-1", "plain"]:
706
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
707
+ terminate_response = []
708
+ elif prompt_type == 'simple_instruct':
709
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
710
+ terminate_response = []
711
+ elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
712
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
713
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
714
+
715
+ PreInstruct = """
716
+ ### Instruction:
717
+ """
718
+
719
+ PreInput = """
720
+ ### Input:
721
+ """
722
+
723
+ PreResponse = """
724
+ ### Response:
725
+ """
726
+ if prompt_type in [7, "7", "instruct_with_end"]:
727
+ terminate_response = ['### End']
728
+ else:
729
+ terminate_response = None
730
+ elif prompt_type in [1, "1", "quality"]:
731
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
732
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
733
+
734
+ PreInstruct = """
735
+ ### Instruction:
736
+ """
737
+
738
+ PreInput = """
739
+ ### Input:
740
+ """
741
+
742
+ PreResponse = """
743
+ ### Response:
744
+ """
745
+ terminate_response = None
746
+ elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
747
+ if reduced or context or prompt_type in [2, "2", "human_bot"]:
748
+ preprompt = ''
749
+ else:
750
+ cur_date = time.strftime('%Y-%m-%d')
751
+ cur_time = time.strftime('%H:%M:%S %p %Z')
752
+
753
+ PRE_PROMPT = """\
754
+ Current Date: {}
755
+ Current Time: {}
756
+
757
+ """
758
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
759
+ start = human
760
+ promptB = promptA = '%s%s ' % (preprompt, start)
761
+
762
+ PreInstruct = ""
763
+
764
+ PreInput = None
765
+
766
+ PreResponse = bot
767
+
768
+ terminate_response = [start, PreResponse]
769
+ elif prompt_type in [3, "3", "dai_faq"]:
770
+ promptA = ''
771
+ promptB = 'Answer the following Driverless AI question.\n'
772
+
773
+ PreInstruct = """
774
+ ### Driverless AI frequently asked question:
775
+ """
776
+
777
+ PreInput = None
778
+
779
+ PreResponse = """
780
+ ### Driverless AI documentation answer:
781
+ """
782
+ terminate_response = ['\n\n']
783
+ elif prompt_type in [5, "5", "summarize"]:
784
+ promptA = promptB = PreInput = ''
785
+ PreInstruct = '## Main Text\n\n'
786
+ PreResponse = '\n\n## Summary\n\n'
787
+ terminate_response = None
788
+ elif prompt_type in [6, "6", "instruct_vicuna"]:
789
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
790
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
791
+
792
+ PreInstruct = """
793
+ ### Human:
794
+ """
795
+
796
+ PreInput = None
797
+
798
+ PreResponse = """
799
+ ### Assistant:
800
+ """
801
+ terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
802
+ else:
803
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
804
+
805
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
806
+
807
+
808
+ def generate_prompt(data_point, prompt_type, chat, reduced):
809
+ context = data_point.get('context') if chat else ''
810
+ if context is None:
811
+ context = ''
812
+ instruction = data_point.get('instruction')
813
+ input = data_point.get('input')
814
+ output = data_point.get('output')
815
+ prompt_type = data_point.get('prompt_type', prompt_type)
816
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
817
+ promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
818
+
819
+ prompt = context
820
+
821
+ if input and promptA:
822
+ prompt += f"""{promptA}"""
823
+ elif promptB:
824
+ prompt += f"""{promptB}"""
825
+
826
+ if instruction and PreInstruct is not None and input and PreInput is not None:
827
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
828
+ prompt = inject_newline(prompt_type, prompt)
829
+ elif instruction and input and PreInstruct is None and PreInput is not None:
830
+ prompt += f"""{PreInput}{instruction}
831
+ {input}"""
832
+ prompt = inject_newline(prompt_type, prompt)
833
+ elif input and instruction and PreInput is None and PreInstruct is not None:
834
+ prompt += f"""{PreInstruct}{instruction}
835
+ {input}"""
836
+ prompt = inject_newline(prompt_type, prompt)
837
+ elif instruction and PreInstruct is not None:
838
+ prompt += f"""{PreInstruct}{instruction}"""
839
+ prompt = inject_newline(prompt_type, prompt)
840
+ elif input and PreInput is not None:
841
+ prompt += f"""{PreInput}{input}"""
842
+ prompt = inject_newline(prompt_type, prompt)
843
+ elif input and instruction and PreInput is not None:
844
+ prompt += f"""{PreInput}{instruction}{input}"""
845
+ prompt = inject_newline(prompt_type, prompt)
846
+ elif input and instruction and PreInstruct is not None:
847
+ prompt += f"""{PreInstruct}{instruction}{input}"""
848
+ prompt = inject_newline(prompt_type, prompt)
849
+ elif input and instruction:
850
+ # i.e. for simple_instruct
851
+ prompt += f"""{instruction}: {input}"""
852
+ prompt = inject_newline(prompt_type, prompt)
853
+ elif input:
854
+ prompt += f"""{input}"""
855
+ prompt = inject_newline(prompt_type, prompt)
856
+ elif instruction:
857
+ prompt += f"""{instruction}"""
858
+ prompt = inject_newline(prompt_type, prompt)
859
+
860
+ if PreResponse is not None:
861
+ prompt += f"""{PreResponse}"""
862
+ pre_response = PreResponse # Don't use strip
863
+ else:
864
+ pre_response = ''
865
+
866
+ if output:
867
+ prompt += f"""{output}"""
868
+
869
+ return prompt, pre_response, terminate_response
870
+
871
+
872
+ def inject_newline(prompt_type, prompt):
873
+ if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
874
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
875
+ prompt += '\n'
876
+ return prompt
877
+
878
+
879
+ example_data_point0 = dict(instruction="Summarize",
880
+ input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
881
+ output="Ducks eat and swim at the lake.")
882
+
883
+ example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
884
+ output="Einstein.")
885
+
886
+ example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
887
+ output="Einstein.")
888
+
889
+ example_data_points = [example_data_point0, example_data_point1, example_data_point2]
890
+
891
+
892
+ def test_train_prompt(prompt_type='instruct', data_point=0):
893
+ example_data_point = example_data_points[data_point]
894
+ return generate_prompt(example_data_point, prompt_type, False, False)
895
+
896
+
897
+ def test_debug():
898
+ fire.Fire(train)
899
+
900
+
901
+ if __name__ == "__main__":
902
+ CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
903
+ CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
904
+ log(f"""
905
+ Example runs on 4 GPUs:
906
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
907
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
908
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
909
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
910
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
911
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512 &> 28.log
912
+
913
+ All metrics:
914
+ CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"
915
+
916
+ # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
917
+ rippa>
918
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
919
+ ova>
920
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
921
+ timemachine>
922
+ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2
923
+
924
+ """, flush=True)
925
+
926
+ if os.environ.get("LOCAL_RANK") is None:
927
+ # then not using torchrun, so can't do distributed, ensure CVD set
928
+ assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
929
+
930
+ fire.Fire(train)
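
Illustrative usage sketch for the prompt templating in finetune.py above (the data point text is made up, and it assumes 'instruct' is one of the entries in the prompt_types list defined earlier in the file):

from finetune import generate_prompt

# build an 'instruct'-style training/inference prompt from a data point
data_point = dict(instruction="Summarize",
                  input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
                  output="")
prompt, pre_response, terminate_response = generate_prompt(data_point, "instruct", False, False)
print(prompt)              # instruction and input blocks followed by the "### Response:" cue
print(terminate_response)  # None for 'instruct'; ['### End'] for 'instruct_with_end'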
h2o-logo.svg ADDED
prompter.py ADDED
@@ -0,0 +1,106 @@
1
+ from finetune import generate_prompt
2
+
3
+
4
+ class Prompter(object):
5
+ def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
6
+ allowed_repeat_line_length=10):
7
+ self.prompt_type = prompt_type
8
+ data_point = dict(instruction='', input='', output='')
9
+ _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
10
+ self.debug = debug
11
+ self.chat = chat
12
+ self.stream_output = stream_output
13
+ self.repeat_penalty = repeat_penalty
14
+ self.allowed_repeat_line_length = allowed_repeat_line_length
15
+
16
+ def generate_prompt(self, data_point):
17
+ reduced = False
18
+ prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
19
+ if self.debug:
20
+ print("prompt: ", prompt, flush=True)
21
+ self.prompt = prompt
22
+ return prompt
23
+
24
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
25
+ if isinstance(outputs, str):
26
+ outputs = [outputs]
27
+ if self.debug:
28
+ print("output: ", '\n\n'.join(outputs), flush=True)
29
+ if prompt is not None:
30
+ self.prompt = prompt
+ else:
+ prompt = self.prompt  # fall back to the prompt captured by generate_prompt(), since it is sliced off below
31
+
32
+ def clean_response(response):
33
+ meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
34
+ for word in meaningless_words:
35
+ response = response.replace(word, "")
36
+ if sanitize_bot_response:
37
+ from better_profanity import profanity
38
+ response = profanity.censor(response)
39
+ response = response.strip("\n")
40
+ return response
41
+
42
+ def clean_repeats(response):
43
+ lines = response.split('\n')
44
+ new_lines = []
45
+ for line in lines:
46
+ if line not in new_lines or len(line) < self.allowed_repeat_line_length:
+ new_lines.append(line)
47
+ if self.debug and len(lines) != len(new_lines):
48
+ print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
49
+ response = '\n'.join(new_lines)
50
+ return response
51
+
52
+ multi_output = len(outputs) > 1
53
+
54
+ for oi, output in enumerate(outputs):
55
+ if self.prompt_type in [-1, '-1', 'plain']:  # 'plain' maps to -1 in finetune.get_prompt
56
+ output = clean_response(output)
57
+ else:
58
+ # find first instance of pre_response
59
+ # prompt sometimes has odd characters, that mutate length,
60
+ # so can't go by length alone
61
+ if self.pre_response:
62
+ outputi = output.find(prompt)
63
+ if outputi >= 0:
64
+ output = output[outputi + len(prompt):]
65
+ allow_terminate = True
66
+ else:
67
+ # subtraction is risky due to space offsets sometimes, so only do if necessary
68
+ output = output[len(prompt) - len(self.pre_response):]
69
+ # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
70
+ if self.pre_response in output:
71
+ output = output.split(self.pre_response)[1]
72
+ allow_terminate = True
73
+ else:
74
+ print("Failure of parsing: %s" % output, flush=True)
75
+ allow_terminate = False
76
+ else:
77
+ allow_terminate = True
78
+ output = output[len(prompt):]
79
+ # clean after subtract prompt out, so correct removal of pre_response
80
+ output = clean_response(output).strip()
81
+ if self.repeat_penalty:
82
+ output = clean_repeats(output).strip()
83
+ if self.terminate_response and allow_terminate:
84
+ finds = []
85
+ for term in self.terminate_response:
86
+ finds.append(output.find(term))
87
+ finds = [x for x in finds if x >= 0]
88
+ if len(finds) > 0:
89
+ termi = min(finds)  # truncate at the earliest terminator found in the output
90
+ output = output[:termi].strip()
91
+ else:
92
+ output = output.strip()
93
+ else:
94
+ output = output.strip()
95
+ if multi_output:
96
+ # prefix with output counter
97
+ output = "\n=========== Output %d\n\n" % (1 + oi) + output
98
+ if oi > 0:
99
+ # postfix outputs with separator
100
+ output += '\n'
101
+ outputs[oi] = output
102
+ # join all outputs, only one extra new line between outputs
103
+ output = '\n'.join(outputs)
104
+ if self.debug:
105
+ print("outputclean: ", '\n\n'.join(outputs), flush=True)
106
+ return output
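
A rough usage sketch for Prompter (the question and the fake model completion below are placeholders; in app.py the raw text comes from decoding model.generate output):

from prompter import Prompter

prompter = Prompter(prompt_type='instruct_vicuna', chat=False)
prompt = prompter.generate_prompt(dict(instruction='What is H2O.ai?', input='', output=''))
# stand-in for tokenizer.decode(model.generate(...)): echoes the prompt, answers, then rambles
raw_output = prompt + "H2O.ai builds open-source machine learning tools.\n### Human: thanks"
print(prompter.get_response(raw_output, prompt=prompt))
# prints only the answer: the echoed prompt and everything after '### Human:' are stripped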
requirements.txt ADDED
@@ -0,0 +1,44 @@
1
+ # for generate (gradio server) and finetune
2
+ datasets==2.10.1
3
+ sentencepiece==0.1.97
4
+ accelerate==0.18.0
5
+ gradio==3.27.0
6
+ huggingface_hub==0.13.4
7
+ appdirs==1.4.4
8
+ fire==0.5.0
9
+ docutils==0.19
10
+ torch==2.0.0
11
+ evaluate==0.4.0
12
+ rouge_score==0.1.2
13
+ sacrebleu==2.3.1
14
+ scikit-learn==1.2.2
15
+ alt-profanity-check==1.2.2
16
+ better-profanity==0.6.1
17
+ numpy==1.24.2
18
+ pandas==1.5.3
19
+ matplotlib==3.7.1
20
+ loralib==0.1.1
21
+ bitsandbytes==0.38.1
22
+ git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
23
+ transformers==4.28.1
24
+ tokenizers==0.13.3
25
+
26
+ # optional for finetune
27
+ tensorboard==2.12.1
28
+ neptune==1.1.1
29
+
30
+ # for gradio client
31
+ gradio_client==0.1.3
32
+ beautifulsoup4==4.12.2
33
+ markdown==3.4.1
34
+
35
+ # data and testing
36
+ pytest==7.2.2
37
+ pytest-xdist==3.2.1
38
+ nltk==3.8.1
39
+ textstat==0.7.3
40
+ pandoc==2.3
41
+ pypandoc==1.11
42
+ openpyxl==3.1.2
43
+ lm_dataformat==0.0.20
44
+ bioc==2.0
stopping.py ADDED
@@ -0,0 +1,139 @@
1
+ import traceback
2
+ from queue import Queue
3
+ from threading import Thread
4
+ import collections.abc
5
+
6
+ import torch
7
+ from transformers import StoppingCriteria
8
+
9
+
10
+ class StoppingCriteriaSub(StoppingCriteria):
11
+
12
+ def __init__(self, stops=[], encounters=[]):
13
+ super().__init__()
14
+ assert len(stops) % len(encounters) == 0, "Number of stops must be a multiple of number of encounters"
15
+ self.encounters = encounters
16
+ self.stops = [stop.to("cuda") for stop in stops]
17
+ self.num_stops = [0] * len(stops)
18
+
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
+ for stopi, stop in enumerate(self.stops):
21
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
22
+ self.num_stops[stopi] += 1
23
+ if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
24
+ return True
25
+ # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
+ # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
+ return False
28
+
29
+
30
+ class Stream(StoppingCriteria):
31
+ """
32
+ This class can be used to callback during generation. Keep
33
+ in mind for decoder-only type of transformers, this will include the initial prompted tokens.
34
+
35
+ Args:
36
+ func (`callable`):
37
+ A callable function to apply on first input in list every iteration of generation
38
+ """
39
+
40
+ def __init__(self, func=None):
41
+ self.func = func
42
+
43
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
+ if self.func is not None:
45
+ # only consume first of multiple responses
46
+ self.func(input_ids[0])
47
+ return False
48
+
49
+
50
+ class CallbackToGenerator(collections.abc.Generator):
51
+ """
52
+ A generator wrapper for a function that invokes a callback multiple times.
53
+
54
+ Calling `send` on the generator emits a value from one callback, and returns
55
+ the next.
56
+
57
+ Note this starts a background thread
58
+ """
59
+
60
+ def __init__(self, func, *args, callback=None, **kwargs):
61
+ self.func = func
62
+ self.args = args
63
+ self.kwargs = kwargs
64
+ self.callback = callback
65
+
66
+ self._ready_queue = Queue(1)
67
+ self._done_queue = Queue(1)
68
+ self._done_holder = [False]
69
+
70
+ # local to avoid reference cycles
71
+ ready_queue = self._ready_queue
72
+ done_queue = self._done_queue
73
+ done_holder = self._done_holder
74
+
75
+ def val_callback(value):
76
+ done_queue.put((False, value))
77
+ cmd, val = ready_queue.get()
78
+ if cmd == 'send':
79
+ return val
80
+ elif cmd == 'throw':
81
+ raise val
82
+ else:
83
+ assert False # pragma: no cover
84
+
85
+ def thread_func():
86
+ while True:
87
+ cmd, val = ready_queue.get()
88
+ if cmd == 'send' and val is not None:
89
+ done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
90
+ continue
91
+ break
92
+ try:
93
+ if cmd == 'throw':
94
+ raise val
95
+ ret = func(callback=val_callback, **self.kwargs)
96
+ raise StopIteration(ret) if ret is not None else StopIteration
97
+ except BaseException as e:
98
+ done_holder[0] = True
99
+ done_queue.put((True, e))
100
+
101
+ self._thread = Thread(target=thread_func)
102
+ self._thread.start()
103
+
104
+ def _put(self, *args):
105
+ if self._done_holder[0]:
106
+ raise StopIteration
107
+ self._ready_queue.put(args)
108
+ is_exception, val = self._done_queue.get()
109
+ if is_exception:
110
+ try:
111
+ raise val
112
+ finally:
113
+ # prevent val's traceback containing a reference cycle
114
+ del val
115
+ else:
116
+ return val
117
+
118
+ def send(self, value):
119
+ return self._put('send', value)
120
+
121
+ def throw(self, exc):
122
+ return self._put('throw', exc)
123
+
124
+ def close(self):
125
+ try:
126
+ self.throw(GeneratorExit)
127
+ except StopIteration:
128
+ self._thread.join()
129
+ except GeneratorExit:
130
+ self._thread.join()
131
+ except BaseException:
132
+ self._thread.join()
133
+ raise
134
+ else:
135
+ # yielded again, can't clean up the thread
136
+ raise RuntimeError('Task with callback ignored GeneratorExit')
137
+
138
+ def __del__(self):
139
+ self.close()
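
Two brief sketches of how the classes above are meant to be used (the model name, stop phrase, and toy producer function are placeholders; a CUDA device is assumed because StoppingCriteriaSub moves its stop tensors to "cuda"):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from stopping import StoppingCriteriaSub, CallbackToGenerator

# stop generation once the model starts a new '### Human:' turn
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B").to("cuda")
stop_ids = [torch.tensor(tokenizer("### Human:", add_special_tokens=False)["input_ids"])]
stopping = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids, encounters=[1])])
inputs = tokenizer("### Human: Hello\n### Assistant:", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping)
print(tokenizer.decode(out[0]))

# CallbackToGenerator turns a callback-driven function into an iterator
def producer(callback=None):
    for i in range(3):
        callback(i)  # each value passed to the callback is yielded by the generator

print(list(CallbackToGenerator(producer)))  # -> [0, 1, 2]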
utils.py ADDED
@@ -0,0 +1,39 @@
1
+ import os
2
+ import gc
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def set_seed(seed: int):
9
+ """
10
+ Sets the seed of the entire notebook so results are the same every time we run.
11
+ This is for REPRODUCIBILITY.
12
+ """
13
+ np.random.seed(seed)
14
+ random_state = np.random.RandomState(seed)
15
+ random.seed(seed)
16
+ torch.manual_seed(seed)
17
+ torch.cuda.manual_seed(seed)
18
+ torch.backends.cudnn.deterministic = True
19
+ torch.backends.cudnn.benchmark = False
20
+ os.environ['PYTHONHASHSEED'] = str(seed)
21
+ return random_state
22
+
23
+
24
+ def flatten_list(lis):
25
+ """Given a list, possibly nested to any level, return it flattened."""
26
+ new_lis = []
27
+ for item in lis:
28
+ if type(item) == type([]):
29
+ new_lis.extend(flatten_list(item))
30
+ else:
31
+ new_lis.append(item)
32
+ return new_lis
33
+
34
+
35
+ def clear_torch_cache():
36
+ if torch.cuda.is_available:
37
+ torch.cuda.empty_cache()
38
+ torch.cuda.ipc_collect()
39
+ gc.collect()
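
A minimal sketch of the helpers above in use (the seed value is arbitrary):

from utils import set_seed, flatten_list, clear_torch_cache

random_state = set_seed(1234)  # seeds numpy, random and torch; returns a numpy RandomState
assert flatten_list([1, [2, [3, 4]], 5]) == [1, 2, 3, 4, 5]
clear_torch_cache()            # empties the CUDA cache when available, then runs gc.collect()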