initial commit
Browse files- web_demo.py +127 -0
web_demo.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
from model import is_chinese, get_infer_setting, generate_input, chat
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Generate a model answer for *input_text* conditioned on *image*.

    Args:
        input_text: the user's text prompt.
        image: a decoded PIL image (``image_is_encoded=False`` below).
        history: prior (question, answer) turns; defaults to an empty history.
        request_data: per-request overrides for the default generation parameters.
        is_zh: True when the prompt is Chinese; passed as ``english=not is_zh``.

    Returns:
        The generated answer string.

    NOTE(review): reads the module-level globals ``model`` and ``tokenizer``,
    which are assigned in ``main()`` — this function must not be called before
    the model is loaded.
    """
    # Fix: the original used mutable defaults (history=[], request_data=dict()),
    # which are shared across calls and can leak state between requests.
    history = [] if history is None else history
    request_data = {} if request_data is None else request_data

    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2
    }
    input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    # Inference only — no gradients needed.
    with torch.no_grad():
        answer, history, _ = chat(None, model, tokenizer, input_text, history=history,
                                  image=input_image,
                                  max_length=gen_kwargs['max_length'],
                                  top_p=gen_kwargs['top_p'],
                                  top_k=gen_kwargs['top_k'],
                                  temperature=gen_kwargs['temperature'],
                                  english=not is_zh)
    return answer
|
26 |
+
|
27 |
+
|
28 |
+
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Handle one chat turn triggered from the Gradio UI.

    Args:
        input_text: the user's prompt from the textbox.
        temperature, top_p: sampling parameters from the sliders.
        image_prompt: filepath of the uploaded image, or None if none uploaded.
        result_previous: the Chatbot history as (question, answer) pairs.

    Returns:
        (new_textbox_value, updated_history) — the textbox is cleared on
        success, but kept when the image is missing so the user can retry.
    """
    # Keep only complete turns; placeholder rows (e.g. the greeting) use "".
    result_text = [(q, a) for q, a in result_previous if q != "" and a != ""]
    print(f"history {result_text}")

    is_zh = is_chinese(input_text)
    if image_prompt is None:
        # No image uploaded: append a bilingual error and keep the input text.
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            # Fix: "a image" -> "an image" in the user-facing message.
            result_text.append((input_text, 'Image empty! Please upload an image and retry.'))
        return input_text, result_text
    elif input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    request_para = {"temperature": temperature, "top_p": top_p}
    image = Image.open(image_prompt)
    try:
        # Pass a copy so inference cannot mutate the UI-visible history.
        answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
    except Exception as e:
        # Best-effort handling: any failure is surfaced as a timeout message.
        print(f"error: {e}")
        if is_zh:
            result_text.append((input_text, '超时!请稍等几分钟再重试。'))
        else:
            result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
        return "", result_text

    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text
|
61 |
+
|
62 |
+
|
63 |
+
# Page header shown at the top of the demo (links to the upstream repo).
DESCRIPTION = '''# <a href="https://github.com/THUDM/VisualGLM-6B">VisualGLM</a>'''

# Usage hints rendered in the UI; English variant (shown in main()).
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
# Chinese variant of the same hints (currently unused by main()).
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'

# Footer note linking back to the upstream repository.
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. It would be recommended to check out the repo if you want to see the detail of our model and training process.'
|
69 |
+
|
70 |
+
|
71 |
+
def clear_fn(value):
    """Reset the textbox, the chat history (back to the greeting), and the image."""
    greeting = [("", "Hi, What do you want to know about this image?")]
    return "", greeting, None
|
73 |
+
|
74 |
+
def clear_fn2(value):
    """Reset only the chat history back to the initial greeting."""
    greeting = ("", "Hi, What do you want to know about this image?")
    return [greeting]
|
76 |
+
|
77 |
+
|
78 |
+
def main(args):
    """Load the model, build the Gradio UI, and launch the demo.

    Args:
        args: parsed CLI namespace with ``quant`` (8, 4, or None) and
            ``share`` (whether to create a public gradio link).

    Side effect: assigns the module-level globals ``model`` and ``tokenizer``
    consumed by generate_text_with_image().
    """
    gr.close_all()
    global model, tokenizer
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        print(gr.__version__)  # fix: the original printed the version twice (debug leftover)
        # Both the Generate button and ENTER in the textbox trigger a request.
        run_button.click(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        input_text.submit(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
                          outputs=[input_text, result_text])
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading a new image or clearing it resets the conversation history.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Optional weight quantization: 8-bit or 4-bit; None loads full precision.
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
    # Create a public gradio share link when passed.
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    # Fix: removed `args.share = "True"`, which clobbered the --share flag
    # with a hard-coded truthy *string*, making the CLI option dead code.
    main(args)
|