Spaces:
Sleeping
Sleeping
teowu
commited on
Commit
•
8132ec4
1
Parent(s):
0ebd86a
Add IQA function!
Browse files- app.py +109 -3
- model_worker.py +84 -0
app.py
CHANGED
@@ -113,6 +113,7 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
|
|
113 |
state.append_message(state.roles[0], text)
|
114 |
state.append_message(state.roles[1], None)
|
115 |
state.skip_next = False
|
|
|
116 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
117 |
|
118 |
|
@@ -201,6 +202,92 @@ def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
|
|
201 |
"ip": request.client.host,
|
202 |
}
|
203 |
fout.write(json.dumps(data) + "\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
|
206 |
title_markdown = ("""
|
@@ -208,7 +295,7 @@ title_markdown = ("""
|
|
208 |
|
209 |
<h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
|
210 |
|
211 |
-
<h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h2>
|
212 |
|
213 |
<div align="center">
|
214 |
<div style="display:flex; gap: 0.25rem;" align="center">
|
@@ -218,10 +305,15 @@ title_markdown = ("""
|
|
218 |
</div>
|
219 |
</div>
|
220 |
|
|
|
|
|
|
|
|
|
221 |
""")
|
222 |
|
223 |
|
224 |
tos_markdown = ("""
|
|
|
225 |
### Terms of use
|
226 |
By using this service, users are required to agree to the following terms:
|
227 |
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
@@ -244,7 +336,7 @@ block_css = """
|
|
244 |
"""
|
245 |
|
246 |
def build_demo(embed_mode):
|
247 |
-
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
248 |
with gr.Blocks(title="Q-Instruct-on-mPLUG-Owl-2", theme=gr.themes.Default(), css=block_css) as demo:
|
249 |
state = gr.State()
|
250 |
|
@@ -271,12 +363,14 @@ def build_demo(embed_mode):
|
|
271 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
272 |
|
273 |
with gr.Column(scale=8):
|
274 |
-
chatbot = gr.Chatbot(elem_id="Chatbot", label="Q-Instruct-Chatbot", height=
|
275 |
with gr.Row():
|
276 |
with gr.Column(scale=8):
|
277 |
textbox.render()
|
278 |
with gr.Column(scale=1, min_width=50):
|
279 |
submit_btn = gr.Button(value="Send", variant="primary")
|
|
|
|
|
280 |
with gr.Row(elem_id="buttons") as button_row:
|
281 |
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
282 |
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
@@ -356,6 +450,18 @@ def build_demo(embed_mode):
|
|
356 |
[state, temperature, top_p, max_output_tokens],
|
357 |
[state, chatbot] + btn_list
|
358 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
|
360 |
demo.load(
|
361 |
load_demo,
|
|
|
113 |
state.append_message(state.roles[0], text)
|
114 |
state.append_message(state.roles[1], None)
|
115 |
state.skip_next = False
|
116 |
+
print(text)
|
117 |
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
118 |
|
119 |
|
|
|
202 |
"ip": request.client.host,
|
203 |
}
|
204 |
fout.write(json.dumps(data) + "\n")
|
205 |
+
|
206 |
+
def http_bot_modified(state, request: gr.Request):
|
207 |
+
logger.info(f"http_bot. ip: {request.client.host}")
|
208 |
+
start_tstamp = time.time()
|
209 |
+
if state.skip_next:
|
210 |
+
# This generate call is skipped due to invalid inputs
|
211 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
212 |
+
return
|
213 |
+
|
214 |
+
print(state.messages[-2][1])
|
215 |
+
state.messages[-2][1] = ('<|image|>Rate the quality of the image.',) + state.messages[-2][1][1:]
|
216 |
+
print(state.messages[-2][1])
|
217 |
+
|
218 |
+
if len(state.messages) == state.offset + 2:
|
219 |
+
# First round of conversation
|
220 |
+
template_name = "mplug_owl2"
|
221 |
+
new_state = conv_templates[template_name].copy()
|
222 |
+
new_state.append_message(new_state.roles[0], state.messages[-2][1])
|
223 |
+
new_state.append_message(new_state.roles[1], None)
|
224 |
+
state = new_state
|
225 |
+
|
226 |
+
# Construct prompt
|
227 |
+
prompt = state.get_prompt()
|
228 |
+
|
229 |
+
all_images = state.get_images(return_pil=True)
|
230 |
+
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
|
231 |
+
for image, hash in zip(all_images, all_image_hash):
|
232 |
+
t = datetime.datetime.now()
|
233 |
+
filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
|
234 |
+
if not os.path.isfile(filename):
|
235 |
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
236 |
+
image.save(filename)
|
237 |
+
|
238 |
+
# Make requests
|
239 |
+
pload = {
|
240 |
+
"prompt": prompt,
|
241 |
+
"images": f'List of {len(state.get_images())} images: {all_image_hash}',
|
242 |
+
}
|
243 |
+
logger.info(f"==== request ====\n{pload}")
|
244 |
+
|
245 |
+
pload['images'] = state.get_images()
|
246 |
+
|
247 |
+
state.messages[-1][-1] = "▌"
|
248 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
249 |
+
|
250 |
+
try:
|
251 |
+
# Stream output
|
252 |
+
# response = requests.post(worker_addr + "/worker_generate_stream",
|
253 |
+
# headers=headers, json=pload, stream=True, timeout=10)
|
254 |
+
# for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
255 |
+
response = model.predict_stream_gate(pload)
|
256 |
+
for chunk in response:
|
257 |
+
if chunk:
|
258 |
+
data = json.loads(chunk.decode())
|
259 |
+
if data["error_code"] == 0:
|
260 |
+
output = data["text"][len(prompt):].strip()
|
261 |
+
state.messages[-1][-1] = output + "▌"
|
262 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
263 |
+
else:
|
264 |
+
output = data["text"] + f" (error_code: {data['error_code']})"
|
265 |
+
state.messages[-1][-1] = output
|
266 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
267 |
+
return
|
268 |
+
time.sleep(0.03)
|
269 |
+
except requests.exceptions.RequestException as e:
|
270 |
+
state.messages[-1][-1] = server_error_msg
|
271 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
|
272 |
+
return
|
273 |
+
|
274 |
+
state.messages[-1][-1] = state.messages[-1][-1][:-1]
|
275 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
276 |
+
|
277 |
+
finish_tstamp = time.time()
|
278 |
+
logger.info(f"{output}")
|
279 |
+
|
280 |
+
with open(get_conv_log_filename(), "a") as fout:
|
281 |
+
data = {
|
282 |
+
"tstamp": round(finish_tstamp, 4),
|
283 |
+
"type": "chat",
|
284 |
+
"start": round(start_tstamp, 4),
|
285 |
+
"finish": round(start_tstamp, 4),
|
286 |
+
"state": state.dict(),
|
287 |
+
"images": all_image_hash,
|
288 |
+
"ip": request.client.host,
|
289 |
+
}
|
290 |
+
fout.write(json.dumps(data) + "\n")
|
291 |
|
292 |
|
293 |
title_markdown = ("""
|
|
|
295 |
|
296 |
<h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
|
297 |
|
298 |
+
<h5 align="center"> If you like our project, please give us a star ✨ on [Github](https://github.com/Q-Future/Q-Instruct) for latest update. </h2>
|
299 |
|
300 |
<div align="center">
|
301 |
<div style="display:flex; gap: 0.25rem;" align="center">
|
|
|
305 |
</div>
|
306 |
</div>
|
307 |
|
308 |
+
### Special Usage: *Rate!*
|
309 |
+
To get an image quality score, just upload a new image and click the **Rate!** button. This will redirect to a special method that return a quality score in [0,1].
|
310 |
+
Always make sure that there is some text in the textbox before you click the **Rate!** button.
|
311 |
+
|
312 |
""")
|
313 |
|
314 |
|
315 |
tos_markdown = ("""
|
316 |
+
|
317 |
### Terms of use
|
318 |
By using this service, users are required to agree to the following terms:
|
319 |
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
|
|
336 |
"""
|
337 |
|
338 |
def build_demo(embed_mode):
|
339 |
+
textbox = gr.Textbox(show_label=False, value="Rate the quality of the image.", placeholder="Enter text and press ENTER", container=False)
|
340 |
with gr.Blocks(title="Q-Instruct-on-mPLUG-Owl-2", theme=gr.themes.Default(), css=block_css) as demo:
|
341 |
state = gr.State()
|
342 |
|
|
|
363 |
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
364 |
|
365 |
with gr.Column(scale=8):
|
366 |
+
chatbot = gr.Chatbot(elem_id="Chatbot", label="Q-Instruct-Chatbot", height=750)
|
367 |
with gr.Row():
|
368 |
with gr.Column(scale=8):
|
369 |
textbox.render()
|
370 |
with gr.Column(scale=1, min_width=50):
|
371 |
submit_btn = gr.Button(value="Send", variant="primary")
|
372 |
+
with gr.Column(scale=1, min_width=50):
|
373 |
+
rate_btn = gr.Button(value="Rate!", variant="primary")
|
374 |
with gr.Row(elem_id="buttons") as button_row:
|
375 |
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
376 |
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
|
|
450 |
[state, temperature, top_p, max_output_tokens],
|
451 |
[state, chatbot] + btn_list
|
452 |
)
|
453 |
+
|
454 |
+
rate_btn.click(
|
455 |
+
add_text,
|
456 |
+
[state, textbox, imagebox, image_process_mode],
|
457 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
458 |
+
queue=False,
|
459 |
+
concurrency_limit=10,
|
460 |
+
).then(
|
461 |
+
http_bot_modified,
|
462 |
+
[state],
|
463 |
+
[state, chatbot] + btn_list
|
464 |
+
)
|
465 |
|
466 |
demo.load(
|
467 |
load_demo,
|
model_worker.py
CHANGED
@@ -45,7 +45,65 @@ class ModelWorker:
|
|
45 |
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
46 |
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
47 |
self.is_multimodal = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
@torch.inference_mode()
|
50 |
def generate_stream(self, params):
|
51 |
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
@@ -115,6 +173,32 @@ class ModelWorker:
|
|
115 |
if generated_text.endswith(stop_str):
|
116 |
generated_text = generated_text[:-len(stop_str)]
|
117 |
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
def generate_stream_gate(self, params):
|
120 |
try:
|
|
|
45 |
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
|
46 |
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
|
47 |
self.is_multimodal = True
|
48 |
+
|
49 |
+
@torch.inference_mode()
|
50 |
+
def predict_stream(self, params):
|
51 |
+
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
52 |
+
|
53 |
+
prompt = params["prompt"] + "The quality of the image is"
|
54 |
+
ori_prompt = prompt
|
55 |
+
images = params.get("images", None)
|
56 |
+
num_image_tokens = 0
|
57 |
+
if images is not None and len(images) > 0 and self.is_multimodal:
|
58 |
+
if len(images) > 0:
|
59 |
+
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
|
60 |
+
raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
|
61 |
+
|
62 |
+
images = [load_image_from_base64(image) for image in images]
|
63 |
+
images = process_images(images, image_processor, model.config)
|
64 |
+
|
65 |
+
if type(images) is list:
|
66 |
+
images = [image.to(self.model.device, dtype=torch.float16) for image in images]
|
67 |
+
else:
|
68 |
+
images = images.to(self.model.device, dtype=torch.float16)
|
69 |
+
|
70 |
+
replace_token = DEFAULT_IMAGE_TOKEN
|
71 |
+
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
|
72 |
|
73 |
+
num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
|
74 |
+
else:
|
75 |
+
images = None
|
76 |
+
image_args = {"images": images}
|
77 |
+
else:
|
78 |
+
images = None
|
79 |
+
image_args = {}
|
80 |
+
|
81 |
+
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
|
82 |
+
|
83 |
+
logits = model.forward(
|
84 |
+
input_ids=input_ids,
|
85 |
+
use_cache=True,
|
86 |
+
**image_args).logits[0,-1]
|
87 |
+
|
88 |
+
print(logits.shape)
|
89 |
+
|
90 |
+
softmax_logits = torch.softmax(logits[[1781,6588,6460]], 0)
|
91 |
+
|
92 |
+
print(tokenizer(["good", "average", "poor"]))
|
93 |
+
fake_streamer = []
|
94 |
+
for id_, word in enumerate(["good", "average", "poor"]):
|
95 |
+
stream_ = f"Probability of {word} quality: {softmax_logits[id_].item():.4f};\n"
|
96 |
+
fake_streamer.append(stream_)
|
97 |
+
|
98 |
+
quality_score = 0.5 * softmax_logits[1] + softmax_logits[0]
|
99 |
+
stream_ = f"Quality score: {quality_score:.4f} (range [0,1])."
|
100 |
+
fake_streamer.append(stream_)
|
101 |
+
|
102 |
+
generated_text = ori_prompt.replace("The quality of the image is", "")
|
103 |
+
for new_text in fake_streamer:
|
104 |
+
generated_text += new_text
|
105 |
+
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
106 |
+
|
107 |
@torch.inference_mode()
|
108 |
def generate_stream(self, params):
|
109 |
tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
|
|
|
173 |
if generated_text.endswith(stop_str):
|
174 |
generated_text = generated_text[:-len(stop_str)]
|
175 |
yield json.dumps({"text": generated_text, "error_code": 0}).encode()
|
176 |
+
|
177 |
+
def predict_stream_gate(self, params):
|
178 |
+
try:
|
179 |
+
for x in self.predict_stream(params):
|
180 |
+
yield x
|
181 |
+
except ValueError as e:
|
182 |
+
print("Caught ValueError:", e)
|
183 |
+
ret = {
|
184 |
+
"text": server_error_msg,
|
185 |
+
"error_code": 1,
|
186 |
+
}
|
187 |
+
yield json.dumps(ret).encode()
|
188 |
+
except torch.cuda.CudaError as e:
|
189 |
+
print("Caught torch.cuda.CudaError:", e)
|
190 |
+
ret = {
|
191 |
+
"text": server_error_msg,
|
192 |
+
"error_code": 1,
|
193 |
+
}
|
194 |
+
yield json.dumps(ret).encode()
|
195 |
+
except Exception as e:
|
196 |
+
print("Caught Unknown Error", e)
|
197 |
+
ret = {
|
198 |
+
"text": server_error_msg,
|
199 |
+
"error_code": 1,
|
200 |
+
}
|
201 |
+
yield json.dumps(ret).encode()
|
202 |
|
203 |
def generate_stream_gate(self, params):
|
204 |
try:
|