Update app.py
Browse files
app.py
CHANGED
@@ -1,75 +1,206 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import
|
3 |
import json
|
4 |
-
|
5 |
-
|
6 |
-
from
|
7 |
-
from
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
import gradio as gr
|
4 |
+
import os
|
5 |
import json
|
6 |
+
import requests
|
7 |
+
import time
|
8 |
+
from concurrent.futures import ThreadPoolExecutor
|
9 |
+
from utils import is_chinese, process_image_without_resize, parse_response, templates_agent_cogagent, template_grounding_cogvlm, postprocess_text
|
10 |
+
|
11 |
+
DESCRIPTION = '''<h2 style='text-align: center'> <a href="https://github.com/THUDM/CogVLM"> CogVLM & CogAgent Chat Demo</a> </h2>'''
|
12 |
+
|
13 |
+
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/CogVLM">https://github.com/THUDM/CogVLM</a>. It would be recommended to check out the repo if you want to see the detail of our model.\n\n该demo仅作为测试使用,不支持批量请求。如有大批量需求,欢迎联系[智谱AI](mailto:business@zhipuai.cn)。\n\n请注意该Demo目前仅支持英文,<a href="http://36.103.203.44:7861/">备用网页</a>支持中文。'
|
14 |
+
|
15 |
+
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.<br>Hint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
|
16 |
+
|
17 |
+
GROUNDING_NOTICE = 'Hint: When you check "Grounding", please use the <a href="https://github.com/THUDM/CogVLM/blob/main/utils/utils/template.py#L344">corresponding prompt</a> or the examples below.'
|
18 |
+
|
19 |
+
AGENT_NOTICE = 'Hint: When you check "CogAgent", please use the <a href="https://github.com/THUDM/CogVLM/blob/main/utils/utils/template.py#L761C1-L761C17">corresponding prompt</a> or the examples below.'
|
20 |
+
|
21 |
+
|
22 |
+
default_chatbox = [("", "Hi, What do you want to know about this image?")]
|
23 |
+
|
24 |
+
URL = os.environ.get("URL")
|
25 |
+
|
26 |
+
|
27 |
+
def make_request(URL, headers, data):
|
28 |
+
response = requests.request("POST", URL, headers=headers, data=data, timeout=(60, 100))
|
29 |
+
return response.json()
|
30 |
+
|
31 |
+
def post(
|
32 |
+
input_text,
|
33 |
+
temperature,
|
34 |
+
top_p,
|
35 |
+
top_k,
|
36 |
+
image_prompt,
|
37 |
+
result_previous,
|
38 |
+
hidden_image,
|
39 |
+
grounding,
|
40 |
+
cogagent,
|
41 |
+
grounding_template,
|
42 |
+
agent_template
|
43 |
+
):
|
44 |
+
result_text = [(ele[0], ele[1]) for ele in result_previous]
|
45 |
+
for i in range(len(result_text)-1, -1, -1):
|
46 |
+
if result_text[i][0] == "" or result_text[i][0] == None:
|
47 |
+
del result_text[i]
|
48 |
+
print(f"history {result_text}")
|
49 |
+
|
50 |
+
is_zh = is_chinese(input_text)
|
51 |
+
|
52 |
+
if image_prompt is None:
|
53 |
+
print("Image empty")
|
54 |
+
if is_zh:
|
55 |
+
result_text.append((input_text, '图片为空!请上传图片并重试。'))
|
56 |
+
else:
|
57 |
+
result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
|
58 |
+
return input_text, result_text, hidden_image
|
59 |
+
elif input_text == "":
|
60 |
+
print("Text empty")
|
61 |
+
result_text.append((input_text, 'Text empty! Please enter text and retry.'))
|
62 |
+
return "", result_text, hidden_image
|
63 |
+
|
64 |
+
headers = {
|
65 |
+
"Content-Type": "application/json; charset=UTF-8",
|
66 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.87 Safari/537.36",
|
67 |
+
}
|
68 |
+
if image_prompt:
|
69 |
+
pil_img, encoded_img, image_hash, image_path_grounding = process_image_without_resize(image_prompt)
|
70 |
+
print(f"image_hash:{image_hash}, hidden_image_hash:{hidden_image}")
|
71 |
+
|
72 |
+
if hidden_image is not None and image_hash != hidden_image:
|
73 |
+
print("image has been update")
|
74 |
+
result_text = []
|
75 |
+
hidden_image = image_hash
|
76 |
+
else:
|
77 |
+
encoded_img = None
|
78 |
+
|
79 |
+
model_use = "vlm_chat"
|
80 |
+
if not cogagent and grounding:
|
81 |
+
model_use = "vlm_grounding"
|
82 |
+
if grounding_template:
|
83 |
+
input_text = postprocess_text(grounding_template, input_text)
|
84 |
+
elif cogagent:
|
85 |
+
model_use = "agent_chat"
|
86 |
+
if agent_template is not None and agent_template != "do not use template":
|
87 |
+
input_text = postprocess_text(agent_template, input_text)
|
88 |
+
|
89 |
+
prompt = input_text
|
90 |
+
|
91 |
+
if grounding:
|
92 |
+
prompt += "(with grounding)"
|
93 |
+
|
94 |
+
print(f'request {model_use} model... with prompt {prompt}, grounding_template {grounding_template}, agent_template {agent_template}')
|
95 |
+
data = json.dumps({
|
96 |
+
'model_use': model_use,
|
97 |
+
'is_grounding': grounding,
|
98 |
+
'text': prompt,
|
99 |
+
'history': result_text,
|
100 |
+
'image': encoded_img,
|
101 |
+
'temperature': temperature,
|
102 |
+
'top_p': top_p,
|
103 |
+
'top_k': top_k,
|
104 |
+
'do_sample': True,
|
105 |
+
'max_new_tokens': 2048
|
106 |
+
})
|
107 |
+
try:
|
108 |
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
109 |
+
future = executor.submit(make_request, URL, headers, data)
|
110 |
+
# time.sleep(15)
|
111 |
+
response = future.result() # Blocks until the request is complete
|
112 |
+
# response = requests.request("POST", URL, headers=headers, data=data, timeout=(60, 100)).json()
|
113 |
+
except Exception as e:
|
114 |
+
print("error message", e)
|
115 |
+
if is_zh:
|
116 |
+
result_text.append((input_text, '超时!请稍等几分钟再重试。'))
|
117 |
+
else:
|
118 |
+
result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
|
119 |
+
return "", result_text, hidden_image
|
120 |
+
print('request done...')
|
121 |
+
# response = {'result':input_text}
|
122 |
+
|
123 |
+
answer = str(response['result'])
|
124 |
+
if grounding:
|
125 |
+
parse_response(pil_img, answer, image_path_grounding)
|
126 |
+
new_answer = answer.replace(input_text, "")
|
127 |
+
result_text.append((input_text, new_answer))
|
128 |
+
result_text.append((None, (image_path_grounding,)))
|
129 |
+
else:
|
130 |
+
result_text.append((input_text, answer))
|
131 |
+
print(result_text)
|
132 |
+
print('finished')
|
133 |
+
return "", result_text, hidden_image
|
134 |
+
|
135 |
+
|
136 |
+
def clear_fn(value):
|
137 |
+
return "", default_chatbox, None
|
138 |
+
|
139 |
+
def clear_fn2(value):
|
140 |
+
return default_chatbox
|
141 |
+
|
142 |
+
|
143 |
+
def main():
|
144 |
+
gr.close_all()
|
145 |
+
examples = []
|
146 |
+
with open("./examples/example_inputs.jsonl") as f:
|
147 |
+
for line in f:
|
148 |
+
data = json.loads(line)
|
149 |
+
examples.append(data)
|
150 |
+
|
151 |
+
|
152 |
+
with gr.Blocks(css='style.css') as demo:
|
153 |
+
|
154 |
+
gr.Markdown(DESCRIPTION)
|
155 |
+
gr.Markdown(NOTES)
|
156 |
+
|
157 |
+
with gr.Row():
|
158 |
+
with gr.Column(scale=4.5):
|
159 |
+
with gr.Group():
|
160 |
+
input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
|
161 |
+
with gr.Row():
|
162 |
+
run_button = gr.Button('Generate')
|
163 |
+
clear_button = gr.Button('Clear')
|
164 |
+
|
165 |
+
image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
|
166 |
+
with gr.Row():
|
167 |
+
grounding = gr.Checkbox(label="Grounding")
|
168 |
+
cogagent = gr.Checkbox(label="CogAgent")
|
169 |
+
with gr.Row():
|
170 |
+
# grounding_notice = gr.Markdown(GROUNDING_NOTICE)
|
171 |
+
grounding_template = gr.Dropdown(choices=template_grounding_cogvlm, label="Grounding Template", value=template_grounding_cogvlm[0])
|
172 |
+
# agent_notice = gr.Markdown(AGENT_NOTICE)
|
173 |
+
agent_template = gr.Dropdown(choices=templates_agent_cogagent, label="Agent Template", value=templates_agent_cogagent[0])
|
174 |
+
|
175 |
+
with gr.Row():
|
176 |
+
temperature = gr.Slider(maximum=1, value=0.9, minimum=0, label='Temperature')
|
177 |
+
top_p = gr.Slider(maximum=1, value=0.8, minimum=0, label='Top P')
|
178 |
+
top_k = gr.Slider(maximum=50, value=5, minimum=1, step=1, label='Top K')
|
179 |
+
|
180 |
+
with gr.Column(scale=5.5):
|
181 |
+
result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")], height=550)
|
182 |
+
hidden_image_hash = gr.Textbox(visible=False)
|
183 |
+
|
184 |
+
gr_examples = gr.Examples(examples=[[example["text"], example["image"], example["grounding"], example["cogagent"]] for example in examples],
|
185 |
+
inputs=[input_text, image_prompt, grounding, cogagent],
|
186 |
+
label="Example Inputs (Click to insert an examplet into the input box)",
|
187 |
+
examples_per_page=6)
|
188 |
+
|
189 |
+
gr.Markdown(MAINTENANCE_NOTICE1)
|
190 |
+
|
191 |
+
print(gr.__version__)
|
192 |
+
run_button.click(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, grounding, cogagent, grounding_template, agent_template],
|
193 |
+
outputs=[input_text, result_text, hidden_image_hash])
|
194 |
+
input_text.submit(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, grounding, cogagent, grounding_template, agent_template],
|
195 |
+
outputs=[input_text, result_text, hidden_image_hash])
|
196 |
+
clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
|
197 |
+
image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
|
198 |
+
image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
|
199 |
+
|
200 |
+
print(gr.__version__)
|
201 |
+
|
202 |
+
demo.queue(concurrency_count=10)
|
203 |
+
demo.launch()
|
204 |
+
|
205 |
+
if __name__ == '__main__':
|
206 |
+
main()
|