xu song commited on
Commit
d72c532
·
1 Parent(s): adffeb2
app.py CHANGED
@@ -1,63 +1,180 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
 
 
 
 
 
 
 
6
  """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
27
 
28
- response = ""
 
 
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
41
 
42
  """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 
 
44
  """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
60
-
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ 来自 https://github.com/OpenLMLab/MOSS/blob/main/moss_web_demo_gradio.py
3
+
4
+
5
+ # 单卡报错
6
+ python moss_web_demo_gradio.py --model_name fnlp/moss-moon-003-sft --gpu 0,1,2,3
7
+
8
+ # TODO
9
+ - 第一句:
10
+ - 代码和表格的预览
11
+ - 可编辑chatbot:https://github.com/gradio-app/gradio/issues/4444
12
  """
13
+
14
+ from transformers.generation.utils import logger
15
+
16
+ import gradio as gr
17
+ import argparse
18
+ import warnings
19
+ import torch
20
+ import os
21
+ # from moss_util import generate_query
22
+ from models.qwen2_util import bot
23
+ # generate_query = None
24
+
25
+ # gr.ChatInterface
26
+
27
+ # from gpt35 import build_message_for_gpt35, send_one_query
28
+
29
+ #
30
+ # def postprocess(self, y):
31
+ # if y is None:
32
+ # return []
33
+ # for i, (message, response) in enumerate(y):
34
+ # y[i] = (
35
+ # None if message is None else mdtex2html.convert((message)),
36
+ # None if response is None else mdtex2html.convert(response),
37
+ # )
38
+ # return y
39
+ #
40
+ #
41
+ # gr.Chatbot.postprocess = postprocess
42
+
43
+
44
+ def parse_text(text):
45
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
46
+ lines = text.split("\n")
47
+ lines = [line for line in lines if line != ""]
48
+ count = 0
49
+ for i, line in enumerate(lines):
50
+ if "```" in line:
51
+ count += 1
52
+ items = line.split('`')
53
+ if count % 2 == 1:
54
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
55
+ else:
56
+ lines[i] = f'<br></code></pre>'
57
+ else:
58
+ if i > 0:
59
+ if count % 2 == 1:
60
+ line = line.replace("`", "\`")
61
+ line = line.replace("<", "&lt;")
62
+ line = line.replace(">", "&gt;")
63
+ line = line.replace(" ", "&nbsp;")
64
+ line = line.replace("*", "&ast;")
65
+ line = line.replace("_", "&lowbar;")
66
+ line = line.replace("-", "&#45;")
67
+ line = line.replace(".", "&#46;")
68
+ line = line.replace("!", "&#33;")
69
+ line = line.replace("(", "&#40;")
70
+ line = line.replace(")", "&#41;")
71
+ line = line.replace("$", "&#36;")
72
+ lines[i] = "<br>" + line
73
+ text = "".join(lines)
74
+ return text
75
 
76
 
77
+ def generate_query(chatbot, history):
78
+ if history and history[-1][1] is None: # 该生成response了
79
+ return None, chatbot, history
80
+ query = bot.generate_query(history)
81
+ # chatbot.append((query, ""))
82
+ chatbot.append((query, None))
83
+ history = history + [(query, None)]
84
+ return query, chatbot, history
 
85
 
86
+ def generate_response(query, chatbot, history):
87
+ """
88
+ 自动模式下:query is None,或者 query = history[-1][0]
89
+ 人工模式下:query 是任意值
90
+ :param query:
91
+ :param chatbot:
92
+ :param history:
93
+ :return:
94
+ """
95
+ # messages = build_message_for_gpt35(query, history)
96
+ # response, success = send_one_query(query, messages, model="gpt-35-turbo")
97
+ # response = response["choices"][0]["message"]["content"]
98
 
99
+ #
100
+ if history[-1][1] is not None or chatbot[-1][1] is not None:
101
+ return chatbot, history
102
 
103
+ if query is None:
104
+ query = history[-1][0]
105
+ response = bot.generate_response(query, history[:-1])
106
+ # chatbot.append((query, response))
107
+ history[-1] = (query, response)
108
+ chatbot[-1] = (query, response)
109
+ print(f"chatbot is {chatbot}")
110
+ print(f"history is {history}")
111
+ return chatbot, history
112
 
 
 
 
 
 
 
 
 
113
 
114
+ def reset_user_input():
115
+ return gr.update(value='')
116
+
117
+
118
+ def reset_state():
119
+ return [], []
120
+
121
 
122
  """
123
+ TODO: 使用说明
124
+
125
+ avatar_images
126
  """
127
+ with gr.Blocks() as demo:
128
+ gr.HTML("""<h1 align="center">欢迎使用 self chat 人工智能助手!</h1>""")
129
+
130
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
131
+ system = gr.Textbox(show_label=False, placeholder="You are a helpful assistant.")
132
+ chatbot = gr.Chatbot(avatar_images=("assets/profile.png", "assets/bot.png"))
133
+ with gr.Row():
134
+ with gr.Column(scale=4):
135
+ user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
136
+ with gr.Row():
137
+ generate_query_btn = gr.Button("生成问题")
138
+ regen_btn = gr.Button("🤔️ Regenerate (重试)")
139
+ submit_btn = gr.Button("生成回复", variant="primary")
140
+ stop_btn = gr.Button("停止生成", variant="primary")
141
+ empty_btn = gr.Button("🧹 Clear History (清除历史)")
142
+ with gr.Column(scale=1):
143
+ # generate_query_btn = gr.Button("Generate First Query")
144
+
145
+ clear_btn = gr.Button("重置")
146
+ gr.Dropdown(
147
+ ["moss", "chatglm-2", "chatpdf"],
148
+ value="moss",
149
+ label="问题生成器",
150
+ # info="Will add more animals later!"
151
+ ),
152
+ gr.Dropdown(
153
+ ["moss", "chatglm-2", "gpt3.5-turbo"],
154
+ value="gpt3.5-turbo",
155
+ label="回复生成器",
156
+ # info="Will add more animals later!"
157
+ ),
158
+
159
+ history = gr.State([]) # (message, bot_message)
160
+
161
+ submit_btn.click(generate_response, [user_input, chatbot, history], [chatbot, history],
162
+ show_progress=True)
163
+ # submit_btn.click(reset_user_input, [], [user_input])
164
+
165
+ clear_btn.click(reset_state, outputs=[chatbot, history], show_progress=True)
166
+
167
+ generate_query_btn.click(generate_query, [chatbot, history], outputs=[user_input, chatbot, history], show_progress=True)
168
+
169
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
170
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
171
+ gr.Slider(
172
+ minimum=0.1,
173
+ maximum=1.0,
174
+ value=0.95,
175
+ step=0.05,
176
+ label="Top-p (nucleus sampling)",
177
+ ),
178
+
179
+ demo.queue().launch(share=False)
180
+ # demo.queue().launch(share=True)
assets/bot.png ADDED
assets/profile.png ADDED
assets/programmer.png ADDED
assets/robot (1).png ADDED
assets/robot.png ADDED
models/qwen2_util.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "Qwen/Qwen2-0.5B-Instruct"
2
+
3
+ from threading import Thread
4
+ from simulator import Simulator
5
+
6
+ from transformers import TextIteratorStreamer
7
+
8
+
9
+ class Qwen2Simulator(Simulator):
10
+
11
+ def generate_query(self, history):
12
+
13
+ inputs = ""
14
+ if history:
15
+ messages = []
16
+ for query, response in history:
17
+ messages += [
18
+ {"role": "user", "content": query},
19
+ {"role": "assistant", "content": response},
20
+ ]
21
+
22
+ inputs += self.tokenizer.apply_chat_template(
23
+ messages,
24
+ tokenize=False,
25
+ add_generation_prompt=False,
26
+ )
27
+ inputs = inputs + "<|im_start|>user\n"
28
+ input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
29
+ return self._generate(input_ids)
30
+ # for new_text in self._stream_generate(input_ids):
31
+ # yield new_text
32
+
33
+ def generate_response(self, query, history):
34
+ messages = []
35
+ for _query, _response in history:
36
+ if _response is None:
37
+ pass
38
+ messages += [
39
+ {"role": "user", "content": _query},
40
+ {"role": "assistant", "content": _response},
41
+ ]
42
+ messages.append({"role": "user", "content": query})
43
+
44
+ input_ids = self.tokenizer.apply_chat_template(
45
+ messages,
46
+ tokenize=True,
47
+ return_tensors="pt",
48
+ add_generation_prompt=True
49
+ ).to(self.model.device)
50
+ return self._generate(input_ids)
51
+ # for new_text in self._stream_generate(input_ids):
52
+ # yield new_text
53
+
54
+ def _generate(self, input_ids):
55
+
56
+ input_ids_length = input_ids.shape[-1]
57
+ response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
58
+ return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)
59
+
60
+ def _stream_generate(self, input_ids):
61
+ streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True, timeout=60.0,
62
+ skip_special_tokens=True)
63
+
64
+ stream_generation_kwargs = dict(
65
+ input_ids=input_ids,
66
+ streamer=streamer
67
+ ).update(self.generation_kwargs)
68
+ thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs)
69
+ thread.start()
70
+
71
+ for new_text in streamer:
72
+ yield new_text
73
+
74
+
75
+ bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
76
+ # bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct")
77
+
78
+
79
+ #
80
+ # history = [["hi, what your name", "rhino"]]
81
+ # generated_query = bot.generate_query(history)
82
+ # for char in generated_query:
83
+ # print(char)
84
+ #
85
+ # bot.generate_response("1+2*3=", history)
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- huggingface_hub==0.22.2
 
 
 
 
1
+ huggingface_hub==0.22.2
2
+ transformers
3
+ torch
4
+ accelerate
simulator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+
3
+
4
+ class Simulator:
5
+
6
+ def __init__(self, model_name_or_path):
7
+ """
8
+ 在传递 device_map 时,low_cpu_mem_usage 会自动设置为 True
9
+ """
10
+
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
12
+ self.model = AutoModelForCausalLM.from_pretrained(
13
+ model_name_or_path,
14
+ torch_dtype="auto",
15
+ device_map="auto"
16
+ )
17
+ self.model.eval()
18
+
19
+ self.generation_kwargs = dict(
20
+ do_sample=False,
21
+ temperature=0.7,
22
+ max_length=500,
23
+ max_new_tokens=10
24
+ )
25
+
26
+ generation_kwargs = dict(
27
+
28
+ max_length=500,
29
+ max_new_tokens=200
30
+ )
31
+
32
+ print(1)
33
+
34
+ def generate_query(self, history):
35
+ """ user simulator
36
+ :param history:
37
+ :return:
38
+ """
39
+ raise NotImplementedError
40
+
41
+ def generate_response(self, input, history):
42
+ """ assistant simulator
43
+ :param input:
44
+ :param history:
45
+ :return:
46
+ """
47
+ raise NotImplementedError