edwardjiang committed on
Commit
20272bc
1 Parent(s): 7675dc6

Create educhat_gradio.py

Files changed (1)
  1. educhat_gradio.py +323 -0
educhat_gradio.py ADDED
@@ -0,0 +1,323 @@
#!/usr/bin/env python3
import argparse
import os
import warnings

import torch
import transformers
from distutils.util import strtobool
from tokenizers import pre_tokenizers

from transformers.generation.utils import logger
import mdtex2html
import gradio as gr

logger.setLevel("ERROR")
warnings.filterwarnings("ignore")

# Set these in the current process: os.system("export ...") only affects a
# throwaway subshell, so exporting there would have no effect on this script.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:20"
os.environ["batch_size"] = "1"


def _strtobool(x):
    return bool(strtobool(x))
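# For example, _strtobool("true") -> True and _strtobool("0") -> False,
# which lets argparse accept --do_sample true/false style values.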


QA_SPECIAL_TOKENS = {
    "Question": "<|prompter|>",
    "Answer": "<|assistant|>",
    "System": "<|system|>",
    "StartPrefix": "<|prefix_begin|>",
    "EndPrefix": "<|prefix_end|>",
    "InnerThought": "<|inner_thoughts|>",
    "EndOfThought": "<eot>",
}


def format_pairs(pairs, eos_token, add_initial_reply_token=False):
    # Alternate <|prompter|>/<|assistant|> markers over the flat turn list.
    conversations = [
        "{}{}{}".format(
            QA_SPECIAL_TOKENS["Question" if i % 2 == 0 else "Answer"], pairs[i], eos_token)
        for i in range(len(pairs))
    ]
    if add_initial_reply_token:
        conversations.append(QA_SPECIAL_TOKENS["Answer"])
    return conversations
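# A quick illustration of what format_pairs produces (a sketch, assuming
# eos_token="</s>" as configured on the tokenizer below):
#   format_pairs(["Hi", "Hello!", "How are you?"], "</s>", add_initial_reply_token=True)
#   -> ["<|prompter|>Hi</s>", "<|assistant|>Hello!</s>",
#       "<|prompter|>How are you?</s>", "<|assistant|>"]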


def format_system_prefix(prefix, eos_token):
    return "{}{}{}".format(
        QA_SPECIAL_TOKENS["System"],
        prefix,
        eos_token,
    )


def get_specific_model(
    model_name, seq2seqmodel=False, without_head=False, cache_dir=".cache", quantization=False, **kwargs
):
    # Encoder-decoder support for Flan-T5-like models: for now these flags are
    # accepted but unused (a LLaMA causal LM is always loaded); in the future
    # this could be automated. torch_dtype=torch.float16 already loads the
    # weights in half precision, so the extra .half() is redundant.
    model = transformers.LlamaForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16).half().cuda()
    return model
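# If the checkpoint does not fit on the GPU in fp16, a lower-memory loading
# sketch (not used here; requires the bitsandbytes package) would be:
#   model = transformers.LlamaForCausalLM.from_pretrained(
#       model_name, load_in_8bit=True, device_map="auto")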


parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--max_new_tokens", type=int, default=200)
parser.add_argument("--top_k", type=int, default=40)
parser.add_argument("--do_sample", type=_strtobool, default=True)
# parser.add_argument("--system_prefix", type=str, default=None)
parser.add_argument("--per-digit-tokens", action="store_true")

args = parser.parse_args()
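# Example invocation (the checkpoint path is hypothetical):
#   python educhat_gradio.py --model_path /path/to/educhat --max_new_tokens 200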

# Exactly one system_prefix assignment is active at a time; the commented-out
# variants below switch EduChat between conversation modes by flipping the
# Enable/Disable flags in an otherwise identical prompt.

# # Open-domain QA
# system_prefix = \
# "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
# - EduChat是一个由华东师范大学开发的对话式语言模型。
# EduChat的工具
# - Web search: Disable.
# - Calculators: Disable.
# EduChat的能力
# - Inner Thought: Disable.
# 对话主题
# - General: Enable.
# - Psychology: Disable.
# - Socrates: Disable.'''"</s>"

# # Socratic-style teaching
# system_prefix = \
# "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
# - EduChat是一个由华东师范大学开发的对话式语言模型。
# EduChat的工具
# - Web search: Disable.
# - Calculators: Disable.
# EduChat的能力
# - Inner Thought: Disable.
# 对话主题
# - General: Disable.
# - Psychology: Disable.
# - Socrates: Enable.'''"</s>"

# Emotional support (active). The prompt is kept verbatim in Chinese since the
# model was tuned on this exact text; it reads roughly: "You are an AI
# assistant named EduChat. EduChat is a conversational language model
# developed by East China Normal University," followed by EduChat's tools
# (web search, calculators), abilities (inner thought) and conversation
# topics (general, psychology, Socrates), each marked Enable/Disable.
system_prefix = \
    "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
- EduChat是一个由华东师范大学开发的对话式语言模型。
EduChat的工具
- Web search: Disable.
- Calculators: Disable.
EduChat的能力
- Inner Thought: Disable.
对话主题
- General: Disable.
- Psychology: Enable.
- Socrates: Disable.'''"</s>"

# # Emotional support (with InnerThought)
# system_prefix = \
# "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
# - EduChat是一个由华东师范大学开发的对话式语言模型。
# EduChat的工具
# - Web search: Disable.
# - Calculators: Disable.
# EduChat的能力
# - Inner Thought: Enable.
# 对话主题
# - General: Disable.
# - Psychology: Enable.
# - Socrates: Disable.'''"</s>"


print('Loading model......')

model = get_specific_model(args.model_path)

# Gradient checkpointing reduces the number of stored activations, but only
# matters during training; it is effectively a no-op for this inference-only demo.
model.gradient_checkpointing_enable()

print('Loading tokenizer...')
tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)

tokenizer.add_special_tokens(
    {
        "pad_token": "</s>",
        "eos_token": "</s>",
        "sep_token": "<s>",
    }
)
additional_special_tokens = (
    []
    if "additional_special_tokens" not in tokenizer.special_tokens_map
    else tokenizer.special_tokens_map["additional_special_tokens"]
)
additional_special_tokens = list(
    set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))

print("additional_special_tokens:", additional_special_tokens)

tokenizer.add_special_tokens(
    {"additional_special_tokens": additional_special_tokens})

if args.per_digit_tokens:
    # The attribute on the underlying fast tokenizer is `pre_tokenizer` (the
    # original assigned to `pre_processor`, which does not exist);
    # Digits(True) makes every digit its own token.
    tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Digits(True)

human_token_id = tokenizer.additional_special_tokens_ids[
    tokenizer.additional_special_tokens.index(QA_SPECIAL_TOKENS["Question"])
]
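# human_token_id resolves "<|prompter|>" to its token id; it is not used
# further in this script, presumably a leftover from the CLI version.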

# Leftover CLI hints; harmless, but they only appear in the server console.
print('Type "quit" to exit')
print("Press Control + C to restart conversation (spam to exit)")

conversation_history = []


"""Override Chatbot.postprocess"""


def postprocess(self, y):
    # Render each (message, response) pair from Markdown/TeX to HTML before
    # Gradio displays it.
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess
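# Patching the method on the class means every gr.Chatbot created below
# renders its transcript through mdtex2html.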


def parse_text(text):
    """Copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT/: turn ```
    fences into <pre><code> blocks and HTML-escape the fenced content."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                # Opening fence: the text after the backticks names the language.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
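# For instance:
#   parse_text("```python\nprint(1)\n```")
#   -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'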


def predict(input, chatbot, max_length, top_p, temperature, history):
    query = parse_text(input)
    chatbot.append((query, ""))
    # Rebuild the flat turn list from history, then append the new query.
    conversation_history = []
    for i, (old_query, response) in enumerate(history):
        conversation_history.append(old_query)
        conversation_history.append(response)
    conversation_history.append(query)

    query_str = "".join(format_pairs(conversation_history,
                                     tokenizer.eos_token, add_initial_reply_token=True))

    if system_prefix:
        query_str = system_prefix + query_str
    print("query:", query_str)

    batch = tokenizer.encode(
        query_str,
        return_tensors="pt",
    )

    with torch.cuda.amp.autocast():
        out = model.generate(
            input_ids=batch.to(model.device),
            # The maximum number of tokens to generate, ignoring the number
            # of tokens in the prompt.
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            max_length=max_length,
            top_k=args.top_k,
            top_p=top_p,
            temperature=temperature,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Strip a trailing eos token if present, then keep only the text after
    # the final assistant marker.
    if out[0][-1] == tokenizer.eos_token_id:
        response = out[0][:-1]
    else:
        response = out[0]
    response = tokenizer.decode(response).split(QA_SPECIAL_TOKENS["Answer"])[-1]

    conversation_history.append(response)

    with open("./educhat_query_record.txt", 'a+') as f:
        f.write(str(conversation_history) + '\n')

    chatbot[-1] = (query, parse_text(response))
    history = history + [(query, response)]
    print(f"chatbot is {chatbot}")
    print(f"history is {history}")

    return chatbot, history
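# The full prompt string sent to the model therefore looks like (a sketch):
#   <|system|>...system prompt...</s><|prompter|>first query</s><|assistant|>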


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


with gr.Blocks() as demo:
    # Heading text: "Welcome to the EduChat AI assistant!"
    gr.HTML("""<h1 align="center">欢迎使用 EduChat 人工智能助手!</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(
                0, 2048, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.2, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0, 1, value=1, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])  # list of (message, bot_message) pairs

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(inbrowser=True, share=True)
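# share=True asks Gradio to open a temporary public *.gradio.live link in
# addition to the local server; set share=False to keep the demo local-only.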