wangrongsheng committed on
Commit
c665e1d
1 Parent(s): 80a4580

Upload web_demo_mm.py

Browse files
Files changed (1) hide show
  1. webui/web_demo_mm.py +239 -0
webui/web_demo_mm.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """A simple web interactive chat demo based on gradio."""
7
+
8
+ from argparse import ArgumentParser
9
+ from pathlib import Path
10
+
11
+ import copy
12
+ import gradio as gr
13
+ import os
14
+ import re
15
+ import secrets
16
+ import tempfile
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+ from transformers.generation import GenerationConfig
19
+
20
+ DEFAULT_CKPT_PATH = './Qwen-VL-Chat'
21
+ BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
22
+ PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
23
+
24
+
25
def _get_args():
    """Parse the command-line options for the demo.

    Returns an argparse Namespace with: checkpoint_path, cpu_only,
    share, inbrowser, server_port, server_name.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
        help="Checkpoint name or path, default to %(default)r",
    )
    arg_parser.add_argument(
        "--cpu-only", action="store_true",
        help="Run demo with CPU only",
    )
    arg_parser.add_argument(
        "--share", action="store_true", default=False,
        help="Create a publicly shareable link for the interface.",
    )
    arg_parser.add_argument(
        "--inbrowser", action="store_true", default=False,
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    arg_parser.add_argument(
        "--server-port", type=int, default=7860,
        help="Demo server port.",
    )
    arg_parser.add_argument(
        "--server-name", type=str, default="0.0.0.0",
        help="Demo server name.",
    )
    return arg_parser.parse_args()
42
+
43
+
44
def _load_model_tokenizer(args):
    """Load tokenizer, model (eval mode) and generation config from the checkpoint.

    The device placement is "cpu" when --cpu-only was given, otherwise "cuda".
    Returns a (model, tokenizer) pair.
    """
    # Keyword arguments shared by every from_pretrained call below.
    hub_kwargs = dict(trust_remote_code=True, resume_download=True)

    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path, **hub_kwargs)

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map="cpu" if args.cpu_only else "cuda",
        **hub_kwargs,
    ).eval()

    model.generation_config = GenerationConfig.from_pretrained(
        args.checkpoint_path, **hub_kwargs,
    )

    return model, tokenizer
65
+
66
+
67
+ def _parse_text(text):
68
+ lines = text.split("\n")
69
+ lines = [line for line in lines if line != ""]
70
+ count = 0
71
+ for i, line in enumerate(lines):
72
+ if "```" in line:
73
+ count += 1
74
+ items = line.split("`")
75
+ if count % 2 == 1:
76
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
77
+ else:
78
+ lines[i] = f"<br></code></pre>"
79
+ else:
80
+ if i > 0:
81
+ if count % 2 == 1:
82
+ line = line.replace("`", r"\`")
83
+ line = line.replace("<", "&lt;")
84
+ line = line.replace(">", "&gt;")
85
+ line = line.replace(" ", "&nbsp;")
86
+ line = line.replace("*", "&ast;")
87
+ line = line.replace("_", "&lowbar;")
88
+ line = line.replace("-", "&#45;")
89
+ line = line.replace(".", "&#46;")
90
+ line = line.replace("!", "&#33;")
91
+ line = line.replace("(", "&#40;")
92
+ line = line.replace(")", "&#41;")
93
+ line = line.replace("$", "&#36;")
94
+ lines[i] = "<br>" + line
95
+ text = "".join(lines)
96
+ return text
97
+
98
+
99
def _launch_demo(args, model, tokenizer):
    """Build the gradio Blocks UI and run the chat demo server.

    Two parallel conversation records are maintained:
      * the chatbot widget value — HTML-escaped pairs for display;
      * ``task_history`` (gr.State) — raw pairs fed back to ``model.chat``.
    """
    # Directory where rendered bounding-box images are written so gradio
    # can serve them; falls back to <system tmp>/gradio.
    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )

    def predict(_chatbot, task_history):
        # Run one model turn on the latest user query and show the answer.
        chat_query = _chatbot[-1][0]   # escaped text, display side
        query = task_history[-1][0]    # raw text, model side
        print("User: " + _parse_text(query))
        history_cp = copy.deepcopy(task_history)
        full_response = ""

        # Fold consecutive image uploads plus the following text query into
        # a single "Picture N: <img>path</img>"-prefixed message — the
        # format the Qwen-VL chat API expects.
        history_filter = []
        pic_idx = 1
        pre = ""
        for i, (q, a) in enumerate(history_cp):
            if isinstance(q, (tuple, list)):
                # Uploaded files are recorded as (path,) tuples by add_file.
                q = f'Picture {pic_idx}: <img>{q[0]}</img>'
                pre += q + '\n'
                pic_idx += 1
            else:
                pre += q
                history_filter.append((pre, a))
                pre = ""
        history, message = history_filter[:-1], history_filter[-1][0]
        response, history = model.chat(tokenizer, message, history=history)
        # If the reply carries <box> grounding tags, render them on the last
        # picture; returns None when there is nothing to draw.
        image = tokenizer.draw_bbox_on_latest_picture(response, history)
        if image is not None:
            # Save the rendering under a random, hard-to-guess path and
            # show it as the answer.
            temp_dir = secrets.token_hex(20)
            temp_dir = Path(uploaded_file_dir) / temp_dir
            temp_dir.mkdir(exist_ok=True, parents=True)
            name = f"tmp{secrets.token_hex(5)}.jpg"
            filename = temp_dir / name
            image.save(str(filename))
            _chatbot[-1] = (_parse_text(chat_query), (str(filename),))
            # Strip <ref>/<box> markup; append any leftover plain text as an
            # extra answer-only row.
            chat_response = response.replace("<ref>", "")
            chat_response = chat_response.replace(r"</ref>", "")
            chat_response = re.sub(BOX_TAG_PATTERN, "", chat_response)
            if chat_response != "":
                _chatbot.append((None, chat_response))
        else:
            _chatbot[-1] = (_parse_text(chat_query), response)
            full_response = _parse_text(response)

        # NOTE(review): when an image was drawn, full_response stays "" so
        # the recorded answer is empty — consistent with the init above,
        # but confirm this is intended rather than an oversight.
        task_history[-1] = (query, full_response)
        print("Qwen-VL-Chat: " + _parse_text(full_response))
        return _chatbot

    def regenerate(_chatbot, task_history):
        # Drop the last answer and re-run predict on the same query.
        if not task_history:
            return _chatbot
        item = task_history[-1]
        if item[1] is None:
            # Nothing answered yet — nothing to regenerate.
            return _chatbot
        task_history[-1] = (item[0], None)
        chatbot_item = _chatbot.pop(-1)
        if chatbot_item[0] is None:
            # Popped row was an answer-only row (image + text case): also
            # clear the answer on the preceding (query, answer) row.
            _chatbot[-1] = (_chatbot[-1][0], None)
        else:
            _chatbot.append((chatbot_item[0], None))
        return predict(_chatbot, task_history)

    def add_text(history, task_history, text):
        # Append the user's text to both histories. A single trailing
        # punctuation mark is stripped from the copy sent to the model but
        # kept in the displayed copy.
        task_text = text
        if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
            task_text = text[:-1]
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_file(history, task_history, file):
        # Record an uploaded image as a (path,) tuple in both histories.
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        # Clear the input textbox.
        return gr.update(value="")

    def reset_state(task_history):
        # Clear the model-side history in place and empty the chatbot widget.
        task_history.clear()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""\
<p align="center"><img src="https://modelscope.cn/api/v1/models/qwen/Qwen-7B-Chat/repo?
Revision=master&FilePath=assets/logo.jpeg&View=true" style="height: 80px"/><p>""")
        gr.Markdown("""<center><font size=8>Qwen-VL-Chat Bot</center>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen-VL-Chat, developed by Alibaba Cloud. \
(本WebUI基于Qwen-VL-Chat打造,实现聊天机器人功能。)</center>""")
        gr.Markdown("""\
<center><font size=4>Qwen-VL <a href="https://modelscope.cn/models/qwen/Qwen-VL/summary">🤖 </a> 
| <a href="https://huggingface.co/Qwen/Qwen-VL">🤗</a>&nbsp | 
Qwen-VL-Chat <a href="https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary">🤖 </a> | 
<a href="https://huggingface.co/Qwen/Qwen-VL-Chat">🤗</a>&nbsp | 
&nbsp<a href="https://github.com/QwenLM/Qwen-VL">Github</a></center>""")

        chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=750)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image"])

        # Submit: append the text, then run a model turn; a second click
        # handler clears the textbox in parallel.
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

        gr.Markdown("""\
<font size=2>Note: This demo is governed by the original license of Qwen-VL. \
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
including hate speech, violence, pornography, deception, etc. \
(注:本演示受Qwen-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")

        # Blocking call: serves the UI until the process is stopped.
        demo.queue().launch(
            share=args.share,
            inbrowser=args.inbrowser,
            server_port=args.server_port,
            server_name=args.server_name,
        )
228
+
229
+
230
def main():
    """CLI entry point: parse options, load model assets, launch the web demo."""
    cli_args = _get_args()
    model, tokenizer = _load_model_tokenizer(cli_args)
    _launch_demo(cli_args, model, tokenizer)


if __name__ == '__main__':
    main()