xu song
commited on
Commit
·
d72c532
1
Parent(s):
adffeb2
update
Browse files- app.py +168 -51
- assets/bot.png +0 -0
- assets/profile.png +0 -0
- assets/programmer.png +0 -0
- assets/robot (1).png +0 -0
- assets/robot.png +0 -0
- models/qwen2_util.py +85 -0
- requirements.txt +4 -1
- simulator.py +47 -0
app.py
CHANGED
@@ -1,63 +1,180 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from huggingface_hub import InferenceClient
|
3 |
-
|
4 |
"""
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
"""
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
-
def
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
messages = [{"role": "system", "content": system_message}]
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
|
|
|
|
27 |
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
-
for message in client.chat_completion(
|
31 |
-
messages,
|
32 |
-
max_tokens=max_tokens,
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
"""
|
43 |
-
|
|
|
|
|
44 |
"""
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
来自 https://github.com/OpenLMLab/MOSS/blob/main/moss_web_demo_gradio.py
|
3 |
+
|
4 |
+
|
5 |
+
# 单卡报错
|
6 |
+
python moss_web_demo_gradio.py --model_name fnlp/moss-moon-003-sft --gpu 0,1,2,3
|
7 |
+
|
8 |
+
# TODO
|
9 |
+
- 第一句:
|
10 |
+
- 代码和表格的预览
|
11 |
+
- 可编辑chatbot:https://github.com/gradio-app/gradio/issues/4444
|
12 |
"""
|
13 |
+
|
14 |
+
from transformers.generation.utils import logger
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
import argparse
|
18 |
+
import warnings
|
19 |
+
import torch
|
20 |
+
import os
|
21 |
+
# from moss_util import generate_query
|
22 |
+
from models.qwen2_util import bot
|
23 |
+
# generate_query = None
|
24 |
+
|
25 |
+
# gr.ChatInterface
|
26 |
+
|
27 |
+
# from gpt35 import build_message_for_gpt35, send_one_query
|
28 |
+
|
29 |
+
#
|
30 |
+
# def postprocess(self, y):
|
31 |
+
# if y is None:
|
32 |
+
# return []
|
33 |
+
# for i, (message, response) in enumerate(y):
|
34 |
+
# y[i] = (
|
35 |
+
# None if message is None else mdtex2html.convert((message)),
|
36 |
+
# None if response is None else mdtex2html.convert(response),
|
37 |
+
# )
|
38 |
+
# return y
|
39 |
+
#
|
40 |
+
#
|
41 |
+
# gr.Chatbot.postprocess = postprocess
|
42 |
+
|
43 |
+
|
44 |
+
def parse_text(text):
|
45 |
+
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
|
46 |
+
lines = text.split("\n")
|
47 |
+
lines = [line for line in lines if line != ""]
|
48 |
+
count = 0
|
49 |
+
for i, line in enumerate(lines):
|
50 |
+
if "```" in line:
|
51 |
+
count += 1
|
52 |
+
items = line.split('`')
|
53 |
+
if count % 2 == 1:
|
54 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
55 |
+
else:
|
56 |
+
lines[i] = f'<br></code></pre>'
|
57 |
+
else:
|
58 |
+
if i > 0:
|
59 |
+
if count % 2 == 1:
|
60 |
+
line = line.replace("`", "\`")
|
61 |
+
line = line.replace("<", "<")
|
62 |
+
line = line.replace(">", ">")
|
63 |
+
line = line.replace(" ", " ")
|
64 |
+
line = line.replace("*", "*")
|
65 |
+
line = line.replace("_", "_")
|
66 |
+
line = line.replace("-", "-")
|
67 |
+
line = line.replace(".", ".")
|
68 |
+
line = line.replace("!", "!")
|
69 |
+
line = line.replace("(", "(")
|
70 |
+
line = line.replace(")", ")")
|
71 |
+
line = line.replace("$", "$")
|
72 |
+
lines[i] = "<br>" + line
|
73 |
+
text = "".join(lines)
|
74 |
+
return text
|
75 |
|
76 |
|
77 |
+
def generate_query(chatbot, history):
|
78 |
+
if history and history[-1][1] is None: # 该生成response了
|
79 |
+
return None, chatbot, history
|
80 |
+
query = bot.generate_query(history)
|
81 |
+
# chatbot.append((query, ""))
|
82 |
+
chatbot.append((query, None))
|
83 |
+
history = history + [(query, None)]
|
84 |
+
return query, chatbot, history
|
|
|
85 |
|
86 |
+
def generate_response(query, chatbot, history):
|
87 |
+
"""
|
88 |
+
自动模式下:query is None,或者 query = history[-1][0]
|
89 |
+
人工模式下:query 是任意值
|
90 |
+
:param query:
|
91 |
+
:param chatbot:
|
92 |
+
:param history:
|
93 |
+
:return:
|
94 |
+
"""
|
95 |
+
# messages = build_message_for_gpt35(query, history)
|
96 |
+
# response, success = send_one_query(query, messages, model="gpt-35-turbo")
|
97 |
+
# response = response["choices"][0]["message"]["content"]
|
98 |
|
99 |
+
#
|
100 |
+
if history[-1][1] is not None or chatbot[-1][1] is not None:
|
101 |
+
return chatbot, history
|
102 |
|
103 |
+
if query is None:
|
104 |
+
query = history[-1][0]
|
105 |
+
response = bot.generate_response(query, history[:-1])
|
106 |
+
# chatbot.append((query, response))
|
107 |
+
history[-1] = (query, response)
|
108 |
+
chatbot[-1] = (query, response)
|
109 |
+
print(f"chatbot is {chatbot}")
|
110 |
+
print(f"history is {history}")
|
111 |
+
return chatbot, history
|
112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
+
def reset_user_input():
|
115 |
+
return gr.update(value='')
|
116 |
+
|
117 |
+
|
118 |
+
def reset_state():
|
119 |
+
return [], []
|
120 |
+
|
121 |
|
122 |
"""
|
123 |
+
TODO: 使用说明
|
124 |
+
|
125 |
+
avatar_images
|
126 |
"""
|
127 |
+
with gr.Blocks() as demo:
|
128 |
+
gr.HTML("""<h1 align="center">欢迎使用 self chat 人工智能助手!</h1>""")
|
129 |
+
|
130 |
+
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
131 |
+
system = gr.Textbox(show_label=False, placeholder="You are a helpful assistant.")
|
132 |
+
chatbot = gr.Chatbot(avatar_images=("assets/profile.png", "assets/bot.png"))
|
133 |
+
with gr.Row():
|
134 |
+
with gr.Column(scale=4):
|
135 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
|
136 |
+
with gr.Row():
|
137 |
+
generate_query_btn = gr.Button("生成问题")
|
138 |
+
regen_btn = gr.Button("🤔️ Regenerate (重试)")
|
139 |
+
submit_btn = gr.Button("生成回复", variant="primary")
|
140 |
+
stop_btn = gr.Button("停止生成", variant="primary")
|
141 |
+
empty_btn = gr.Button("🧹 Clear History (清除历史)")
|
142 |
+
with gr.Column(scale=1):
|
143 |
+
# generate_query_btn = gr.Button("Generate First Query")
|
144 |
+
|
145 |
+
clear_btn = gr.Button("重置")
|
146 |
+
gr.Dropdown(
|
147 |
+
["moss", "chatglm-2", "chatpdf"],
|
148 |
+
value="moss",
|
149 |
+
label="问题生成器",
|
150 |
+
# info="Will add more animals later!"
|
151 |
+
),
|
152 |
+
gr.Dropdown(
|
153 |
+
["moss", "chatglm-2", "gpt3.5-turbo"],
|
154 |
+
value="gpt3.5-turbo",
|
155 |
+
label="回复生成器",
|
156 |
+
# info="Will add more animals later!"
|
157 |
+
),
|
158 |
+
|
159 |
+
history = gr.State([]) # (message, bot_message)
|
160 |
+
|
161 |
+
submit_btn.click(generate_response, [user_input, chatbot, history], [chatbot, history],
|
162 |
+
show_progress=True)
|
163 |
+
# submit_btn.click(reset_user_input, [], [user_input])
|
164 |
+
|
165 |
+
clear_btn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
166 |
+
|
167 |
+
generate_query_btn.click(generate_query, [chatbot, history], outputs=[user_input, chatbot, history], show_progress=True)
|
168 |
+
|
169 |
+
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
170 |
+
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
171 |
+
gr.Slider(
|
172 |
+
minimum=0.1,
|
173 |
+
maximum=1.0,
|
174 |
+
value=0.95,
|
175 |
+
step=0.05,
|
176 |
+
label="Top-p (nucleus sampling)",
|
177 |
+
),
|
178 |
+
|
179 |
+
demo.queue().launch(share=False)
|
180 |
+
# demo.queue().launch(share=True)
|
assets/bot.png
ADDED
assets/profile.png
ADDED
assets/programmer.png
ADDED
assets/robot (1).png
ADDED
assets/robot.png
ADDED
models/qwen2_util.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Qwen/Qwen2-0.5B-Instruct"
|
2 |
+
|
3 |
+
from threading import Thread
|
4 |
+
from simulator import Simulator
|
5 |
+
|
6 |
+
from transformers import TextIteratorStreamer
|
7 |
+
|
8 |
+
|
9 |
+
class Qwen2Simulator(Simulator):
|
10 |
+
|
11 |
+
def generate_query(self, history):
|
12 |
+
|
13 |
+
inputs = ""
|
14 |
+
if history:
|
15 |
+
messages = []
|
16 |
+
for query, response in history:
|
17 |
+
messages += [
|
18 |
+
{"role": "user", "content": query},
|
19 |
+
{"role": "assistant", "content": response},
|
20 |
+
]
|
21 |
+
|
22 |
+
inputs += self.tokenizer.apply_chat_template(
|
23 |
+
messages,
|
24 |
+
tokenize=False,
|
25 |
+
add_generation_prompt=False,
|
26 |
+
)
|
27 |
+
inputs = inputs + "<|im_start|>user\n"
|
28 |
+
input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
|
29 |
+
return self._generate(input_ids)
|
30 |
+
# for new_text in self._stream_generate(input_ids):
|
31 |
+
# yield new_text
|
32 |
+
|
33 |
+
def generate_response(self, query, history):
|
34 |
+
messages = []
|
35 |
+
for _query, _response in history:
|
36 |
+
if _response is None:
|
37 |
+
pass
|
38 |
+
messages += [
|
39 |
+
{"role": "user", "content": _query},
|
40 |
+
{"role": "assistant", "content": _response},
|
41 |
+
]
|
42 |
+
messages.append({"role": "user", "content": query})
|
43 |
+
|
44 |
+
input_ids = self.tokenizer.apply_chat_template(
|
45 |
+
messages,
|
46 |
+
tokenize=True,
|
47 |
+
return_tensors="pt",
|
48 |
+
add_generation_prompt=True
|
49 |
+
).to(self.model.device)
|
50 |
+
return self._generate(input_ids)
|
51 |
+
# for new_text in self._stream_generate(input_ids):
|
52 |
+
# yield new_text
|
53 |
+
|
54 |
+
def _generate(self, input_ids):
|
55 |
+
|
56 |
+
input_ids_length = input_ids.shape[-1]
|
57 |
+
response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
|
58 |
+
return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)
|
59 |
+
|
60 |
+
def _stream_generate(self, input_ids):
|
61 |
+
streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True, timeout=60.0,
|
62 |
+
skip_special_tokens=True)
|
63 |
+
|
64 |
+
stream_generation_kwargs = dict(
|
65 |
+
input_ids=input_ids,
|
66 |
+
streamer=streamer
|
67 |
+
).update(self.generation_kwargs)
|
68 |
+
thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs)
|
69 |
+
thread.start()
|
70 |
+
|
71 |
+
for new_text in streamer:
|
72 |
+
yield new_text
|
73 |
+
|
74 |
+
|
75 |
+
bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
|
76 |
+
# bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")
|
77 |
+
|
78 |
+
|
79 |
+
#
|
80 |
+
# history = [["hi, what your name", "rhino"]]
|
81 |
+
# generated_query = bot.generate_query(history)
|
82 |
+
# for char in generated_query:
|
83 |
+
# print(char)
|
84 |
+
#
|
85 |
+
# bot.generate_response("1+2*3=", history)
|
requirements.txt
CHANGED
@@ -1 +1,4 @@
|
|
1 |
-
huggingface_hub==0.22.2
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub==0.22.2
|
2 |
+
transformers
|
3 |
+
torch
|
4 |
+
accelerate
|
simulator.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
|
3 |
+
|
4 |
+
class Simulator:
|
5 |
+
|
6 |
+
def __init__(self, model_name_or_path):
|
7 |
+
"""
|
8 |
+
在传递 device_map 时,low_cpu_mem_usage 会自动设置为 True
|
9 |
+
"""
|
10 |
+
|
11 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
12 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
13 |
+
model_name_or_path,
|
14 |
+
torch_dtype="auto",
|
15 |
+
device_map="auto"
|
16 |
+
)
|
17 |
+
self.model.eval()
|
18 |
+
|
19 |
+
self.generation_kwargs = dict(
|
20 |
+
do_sample=False,
|
21 |
+
temperature=0.7,
|
22 |
+
max_length=500,
|
23 |
+
max_new_tokens=10
|
24 |
+
)
|
25 |
+
|
26 |
+
generation_kwargs = dict(
|
27 |
+
|
28 |
+
max_length=500,
|
29 |
+
max_new_tokens=200
|
30 |
+
)
|
31 |
+
|
32 |
+
print(1)
|
33 |
+
|
34 |
+
def generate_query(self, history):
|
35 |
+
""" user simulator
|
36 |
+
:param history:
|
37 |
+
:return:
|
38 |
+
"""
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
def generate_response(self, input, history):
|
42 |
+
""" assistant simulator
|
43 |
+
:param input:
|
44 |
+
:param history:
|
45 |
+
:return:
|
46 |
+
"""
|
47 |
+
raise NotImplementedError
|