edwardjiang committed
Commit f086839
1 Parent(s): f984eee

Update launch.py

Files changed (1)
  1. launch.py +319 -1
launch.py CHANGED
@@ -1 +1,319 @@
- import module
+ import gradio as gr
+ import argparse
+ import torch
+ import transformers
+ from distutils.util import strtobool
+ from tokenizers import pre_tokenizers
+
+ from transformers.generation.utils import logger
+ import mdtex2html
+ import warnings
+
+
+ logger.setLevel("ERROR")
+ warnings.filterwarnings("ignore")
+
+
+ def _strtobool(x):
+     return bool(strtobool(x))
+
+
+ QA_SPECIAL_TOKENS = {
+     "Question": "<|prompter|>",
+     "Answer": "<|assistant|>",
+     "System": "<|system|>",
+     "StartPrefix": "<|prefix_begin|>",
+     "EndPrefix": "<|prefix_end|>",
+     "InnerThought": "<|inner_thoughts|>",
+     "EndOfThought": "<eot>"
+ }
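+ # These appear to follow the OpenAssistant-style chat markup; the
+ # InnerThought/EndOfThought pair looks EduChat-specific.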
+
+
+ def format_pairs(pairs, eos_token, add_initial_reply_token=False):
+     conversations = [
+         "{}{}{}".format(
+             QA_SPECIAL_TOKENS["Question" if i % 2 == 0 else "Answer"], pairs[i], eos_token)
+         for i in range(len(pairs))
+     ]
+     if add_initial_reply_token:
+         conversations.append(QA_SPECIAL_TOKENS["Answer"])
+     return conversations
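+ # For example, assuming alternating user/assistant turns:
+ #   format_pairs(["Hi", "Hello!"], "</s>", add_initial_reply_token=True)
+ #   -> ["<|prompter|>Hi</s>", "<|assistant|>Hello!</s>", "<|assistant|>"]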
+
+
+ def format_system_prefix(prefix, eos_token):
+     return "{}{}{}".format(
+         QA_SPECIAL_TOKENS["System"],
+         prefix,
+         eos_token,
+     )
+
+
+ def get_specific_model(
+     model_name, seq2seqmodel=False, without_head=False, cache_dir=".cache", quantization=False, **kwargs
+ ):
+     # Encoder-decoder support for Flan-T5-like models could hang off these
+     # arguments; for now the model class is hard-coded, so seq2seqmodel,
+     # without_head, cache_dir and quantization are accepted but unused.
+     model = transformers.LlamaForCausalLM.from_pretrained(model_name, **kwargs)
+
+     return model
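+ # A possible refinement (not what this script does below): ask from_pretrained
+ # for half precision directly instead of converting afterwards, e.g.
+ #   transformers.LlamaForCausalLM.from_pretrained(
+ #       model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True)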
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_path", type=str, required=True)
+ parser.add_argument("--max_new_tokens", type=int, default=200)
+ parser.add_argument("--top_k", type=int, default=40)
+ parser.add_argument("--do_sample", type=_strtobool, default=True)
+ # parser.add_argument("--system_prefix", type=str, default=None)
+ parser.add_argument("--per-digit-tokens", action="store_true")
+
+
+ args = parser.parse_args()
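+ # Typical invocation (note that only the tokenizer is loaded from
+ # --model_path; the model weights path further below is hard-coded):
+ #   python launch.py --model_path models/ecnu-icalk/educhat-sft-002-7b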
+
+ # # Open QA (开放问答)
+ # system_prefix = \
+ #     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
+ # - EduChat是一个由华东师范大学开发的对话式语言模型。
+ # EduChat的工具
+ # - Web search: Disable.
+ # - Calculators: Disable.
+ # EduChat的能力
+ # - Inner Thought: Disable.
+ # 对话主题
+ # - General: Enable.
+ # - Psychology: Disable.
+ # - Socrates: Disable.'''"</s>"
+
+ # # Socratic teaching (启发式教学)
+ # system_prefix = \
+ #     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
+ # - EduChat是一个由华东师范大学开发的对话式语言模型。
+ # EduChat的工具
+ # - Web search: Disable.
+ # - Calculators: Disable.
+ # EduChat的能力
+ # - Inner Thought: Disable.
+ # 对话主题
+ # - General: Disable.
+ # - Psychology: Disable.
+ # - Socrates: Enable.'''"</s>"
+
+ # Emotional support (情感支持)
+ system_prefix = \
+     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
+ - EduChat是一个由华东师范大学开发的对话式语言模型。
+ EduChat的工具
+ - Web search: Disable.
+ - Calculators: Disable.
+ EduChat的能力
+ - Inner Thought: Disable.
+ 对话主题
+ - General: Disable.
+ - Psychology: Enable.
+ - Socrates: Disable.'''"</s>"
+
+ # # Emotional support with Inner Thought (情感支持)
+ # system_prefix = \
+ #     "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
+ # - EduChat是一个由华东师范大学开发的对话式语言模型。
+ # EduChat的工具
+ # - Web search: Disable.
+ # - Calculators: Disable.
+ # EduChat的能力
+ # - Inner Thought: Enable.
+ # 对话主题
+ # - General: Disable.
+ # - Psychology: Enable.
+ # - Socrates: Disable.'''"</s>"
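+ # The active prefix (Emotional support) tells the model, in Chinese: it is an
+ # AI assistant named EduChat, a dialogue model developed by East China Normal
+ # University, with web search, calculators and inner thought disabled and the
+ # Psychology topic enabled. In predict() it is prepended verbatim to the
+ # formatted conversation, e.g.:
+ #   system_prefix + "<|prompter|>你好</s><|assistant|>"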
+
+
+ print('Loading model...')
+
+ # NOTE: weights are loaded from a hard-coded path, not from args.model_path.
+ model = get_specific_model("models/ecnu-icalk/educhat-sft-002-7b")
+
+ model.half().cuda()
+ model.gradient_checkpointing_enable()  # reduce number of stored activations (only matters for training)
+
+ print('Loading tokenizer...')
+ tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)
+
+ tokenizer.add_special_tokens(
+     {
+         "pad_token": "</s>",
+         "eos_token": "</s>",
+         "sep_token": "<s>",
+     }
+ )
+ additional_special_tokens = (
+     []
+     if "additional_special_tokens" not in tokenizer.special_tokens_map
+     else tokenizer.special_tokens_map["additional_special_tokens"]
+ )
+ additional_special_tokens = list(
+     set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
+
+ print("additional_special_tokens:", additional_special_tokens)
+
+ tokenizer.add_special_tokens(
+     {"additional_special_tokens": additional_special_tokens})
+
+ if args.per_digit_tokens:
+     # The backing tokenizers attribute is pre_tokenizer ("pre_processor"
+     # does not exist and would raise an AttributeError).
+     tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Digits(True)
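+ # Digits(True) pre-tokenizes every number into individual digits
+ # ("12345" -> "1 2 3 4 5"), which helps some models handle arithmetic.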
+
+ human_token_id = tokenizer.additional_special_tokens_ids[
+     tokenizer.additional_special_tokens.index(QA_SPECIAL_TOKENS["Question"])
+ ]  # looked up for completeness; not used below
+
+ # These prompts appear to be left over from a console version of the script.
+ print('Type "quit" to exit')
+ print("Press Control + C to restart conversation (spam to exit)")
+
+ conversation_history = []
+
+
+ """Override Chatbot.postprocess"""
+
+
+ def postprocess(self, y):
+     if y is None:
+         return []
+     for i, (message, response) in enumerate(y):
+         y[i] = (
+             None if message is None else mdtex2html.convert(message),
+             None if response is None else mdtex2html.convert(response),
+         )
+     return y
+
+
+ gr.Chatbot.postprocess = postprocess
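+ # Monkey-patching gr.Chatbot.postprocess makes the chat window render
+ # messages through mdtex2html (Markdown + TeX -> HTML) instead of plain text.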
+
+
+ def parse_text(text):
+     """copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split('`')
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = '<br></code></pre>'
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", "\\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
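+ # Example: parse_text("```python\nprint(1)\n```") returns
+ # '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'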
+
+
+ def predict(input, chatbot, max_length, top_p, temperature, history):
+     query = parse_text(input)
+     chatbot.append((query, ""))
+     conversation_history = []
+     for i, (old_query, response) in enumerate(history):
+         conversation_history.append(old_query)
+         conversation_history.append(response)
+
+     conversation_history.append(query)
+
+     query_str = "".join(format_pairs(conversation_history,
+                                      tokenizer.eos_token, add_initial_reply_token=True))
+
+     if system_prefix:
+         query_str = system_prefix + query_str
+     print("query:", query_str)
+
+     batch = tokenizer.encode(
+         query_str,
+         return_tensors="pt",
+     )
+
+     with torch.cuda.amp.autocast():
+         out = model.generate(
+             input_ids=batch.to(model.device),
+             # The maximum number of tokens to generate, ignoring the number
+             # of tokens in the prompt; takes precedence over max_length below.
+             max_new_tokens=args.max_new_tokens,
+             do_sample=args.do_sample,
+             max_length=max_length,
+             top_k=args.top_k,
+             top_p=top_p,
+             temperature=temperature,
+             eos_token_id=tokenizer.eos_token_id,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+
+     # Strip the trailing EOS token, if any, before decoding.
+     if out[0][-1] == tokenizer.eos_token_id:
+         response = out[0][:-1]
+     else:
+         response = out[0]
+
+     response = tokenizer.decode(response).split(QA_SPECIAL_TOKENS["Answer"])[-1]
+
+     conversation_history.append(response)
+
+     # Append every exchange to a local log file.
+     with open("./educhat_query_record.txt", 'a+') as f:
+         f.write(str(conversation_history) + '\n')
+
+     chatbot[-1] = (query, parse_text(response))
+     history = history + [(query, response)]
+     print(f"chatbot is {chatbot}")
+     print(f"history is {history}")
+
+     return chatbot, history
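+ # Wired to the Submit button below; a direct call looks like:
+ #   chatbot, history = predict("你好", [], 2048, 0.2, 1.0, [])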
+
+
+ def reset_user_input():
+     return gr.update(value='')
+
+
+ def reset_state():
+     return [], []
+
+
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">欢迎使用 EduChat 人工智能助手!</h1>""")  # "Welcome to the EduChat AI assistant!"
+
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         with gr.Column(scale=4):
+             with gr.Column(scale=12):
+                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
+                     container=False)
+             with gr.Column(min_width=32, scale=1):
+                 submitBtn = gr.Button("Submit", variant="primary")
+         with gr.Column(scale=1):
+             emptyBtn = gr.Button("Clear History")
+             max_length = gr.Slider(
+                 0, 2048, value=2048, step=1.0, label="Maximum length", interactive=True)
+             top_p = gr.Slider(0, 1, value=0.2, step=0.01,
+                               label="Top P", interactive=True)
+             temperature = gr.Slider(
+                 0, 1, value=1, step=0.01, label="Temperature", interactive=True)
+
+     history = gr.State([])  # (message, bot_message)
+
+     submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
+                     show_progress=True)
+     submitBtn.click(reset_user_input, [], [user_input])
+
+     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
+
+ demo.queue().launch(inbrowser=True, share=True)