muxingyin committed on
Commit
e11b993
1 Parent(s): 5f5f6ff

initial commit

Browse files
Files changed (1) hide show
  1. web_demo.py +127 -0
web_demo.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ import json
5
+ from model import is_chinese, get_infer_setting, generate_input, chat
6
+ import torch
7
+
8
def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Run one round of image-grounded chat through the global model.

    Args:
        input_text: The user's prompt text.
        image: PIL image to condition the generation on (passed to
            `generate_input` with `image_is_encoded=False`).
        history: Prior (question, answer) pairs; defaults to an empty history.
        request_data: Optional dict of overrides for the default generation
            parameters below (e.g. temperature, top_p).
        is_zh: True when the prompt is Chinese; `chat` receives `english=not is_zh`.

    Returns:
        The generated answer string.

    Note: reads the module-level globals `model` and `tokenizer` set by `main`.
    """
    # Fix: the original used mutable defaults (history=[], request_data=dict()),
    # which are shared across calls and can leak state between requests.
    if history is None:
        history = []
    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2,
    }
    if request_data:
        input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        answer, history, _ = chat(None, model, tokenizer, input_text,
                                  history=history,
                                  image=input_image,
                                  max_length=gen_kwargs['max_length'],
                                  top_p=gen_kwargs['top_p'],
                                  top_k=gen_kwargs['top_k'],
                                  temperature=gen_kwargs['temperature'],
                                  english=not is_zh)
    return answer
26
+
27
+
28
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Gradio callback: answer `input_text` about the uploaded image.

    Args:
        input_text: Prompt typed by the user.
        temperature: Sampling temperature from the UI slider.
        top_p: Nucleus-sampling parameter from the UI slider.
        image_prompt: File path of the uploaded image, or None if absent.
        result_previous: Chatbot history as (question, answer) pairs.

    Returns:
        (new_textbox_value, updated_history) matching the Gradio outputs
        [input_text, result_text].
    """
    # Drop placeholder turns: the initial greeting row has an empty question,
    # and error rows may have an empty answer.
    result_text = [(q, a) for q, a in result_previous if q != "" and a != ""]
    print(f"history {result_text}")

    is_zh = is_chinese(input_text)
    if image_prompt is None:
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            # Fix: original said "a image" — corrected user-facing grammar.
            result_text.append((input_text, 'Image empty! Please upload an image and retry.'))
        return input_text, result_text
    elif input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    request_para = {"temperature": temperature, "top_p": top_p}
    image = Image.open(image_prompt)
    try:
        answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
    except Exception as e:
        # Deliberate best-effort: any inference failure is surfaced to the user
        # as a timeout so the demo stays up; the real error goes to the console.
        print(f"error: {e}")
        if is_zh:
            result_text.append((input_text, '超时!请稍等几分钟再重试。'))
        else:
            result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
        return "", result_text

    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text
61
+
62
+
63
# Page header shown at the top of the demo (rendered as Markdown/HTML).
DESCRIPTION = '''# <a href="https://github.com/THUDM/VisualGLM-6B">VisualGLM</a>'''

# User-facing troubleshooting hints; NOTICE1 is English, NOTICE2 is the
# Chinese rendering of the same two hints. Only NOTICE1 is displayed in
# the layout below.
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'

# Attribution note rendered below the chat area.
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. It would be recommended to check out the repo if you want to see the detail of our model and training process.'
69
+
70
+
71
def clear_fn(value):
    """Reset textbox, chat history, and image prompt to their initial state.

    Returns a 3-tuple matching the Gradio outputs
    [input_text, result_text, image_prompt].
    """
    greeting = "Hi, What do you want to know about this image?"
    fresh_history = [("", greeting)]
    return "", fresh_history, None
73
+
74
def clear_fn2(value):
    """Return a fresh chat history containing only the greeting turn."""
    initial_turn = ("", "Hi, What do you want to know about this image?")
    return [initial_turn]
76
+
77
+
78
def main(args):
    """Build and launch the Gradio demo.

    Loads the model/tokenizer into module-level globals (read by
    generate_text_with_image) and wires the UI events to request_model.

    Args:
        args: Parsed CLI namespace with `quant` (8/4/None) and `share` (bool).
    """
    gr.close_all()
    # The inference helper reads these as globals instead of taking them
    # as parameters.
    global model, tokenizer
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # NOTE(review): float scale values (4.5 / 5.5) — newer Gradio
            # expects an int here; presumably fine on the version pinned
            # for this demo. Confirm against requirements.
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                    # type="filepath": callbacks receive the image as a path.
                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        # Only the English notice is shown; MAINTENANCE_NOTICE2
                        # is its unused Chinese counterpart.
                        maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
            with gr.Column(scale=5.5):
                # NOTE(review): .style() is deprecated/removed in newer Gradio
                # releases — works only on the version this demo targets.
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        # Debug aid: log the running Gradio version.
        print(gr.__version__)
        # Both the button and ENTER in the textbox trigger the same request.
        run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                          outputs=[input_text, result_text])
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading or clearing the image resets the conversation history.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

    print(gr.__version__)

    # NOTE(review): queue(concurrency_count=...) was removed in Gradio 4;
    # assumes an older Gradio — confirm against the pinned dependency.
    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)
118
+
119
+
120
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    # Optional weight quantization: 8-bit or 4-bit; None loads full precision.
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    # Force a public Gradio share link regardless of the --share flag
    # (demo-deployment behavior kept from the original). Fix: the original
    # assigned the STRING "True", which is truthy but wrong-typed for
    # launch(share=...); use a real boolean.
    args.share = True
    main(args)