teowu commited on
Commit
8132ec4
1 Parent(s): 0ebd86a

Add IQA function!

Browse files
Files changed (2) hide show
  1. app.py +109 -3
  2. model_worker.py +84 -0
app.py CHANGED
@@ -113,6 +113,7 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
113
  state.append_message(state.roles[0], text)
114
  state.append_message(state.roles[1], None)
115
  state.skip_next = False
 
116
  return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
117
 
118
 
@@ -201,6 +202,92 @@ def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
201
  "ip": request.client.host,
202
  }
203
  fout.write(json.dumps(data) + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
  title_markdown = ("""
@@ -208,7 +295,7 @@ title_markdown = ("""
208
 
209
  <h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
210
 
211
- <h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h2>
212
 
213
  <div align="center">
214
  <div style="display:flex; gap: 0.25rem;" align="center">
@@ -218,10 +305,15 @@ title_markdown = ("""
218
  </div>
219
  </div>
220
 
 
 
 
 
221
  """)
222
 
223
 
224
  tos_markdown = ("""
 
225
  ### Terms of use
226
  By using this service, users are required to agree to the following terms:
227
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
@@ -244,7 +336,7 @@ block_css = """
244
  """
245
 
246
  def build_demo(embed_mode):
247
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
248
  with gr.Blocks(title="Q-Instruct-on-mPLUG-Owl-2", theme=gr.themes.Default(), css=block_css) as demo:
249
  state = gr.State()
250
 
@@ -271,12 +363,14 @@ def build_demo(embed_mode):
271
  max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
272
 
273
  with gr.Column(scale=8):
274
- chatbot = gr.Chatbot(elem_id="Chatbot", label="Q-Instruct-Chatbot", height=600)
275
  with gr.Row():
276
  with gr.Column(scale=8):
277
  textbox.render()
278
  with gr.Column(scale=1, min_width=50):
279
  submit_btn = gr.Button(value="Send", variant="primary")
 
 
280
  with gr.Row(elem_id="buttons") as button_row:
281
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
282
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
@@ -356,6 +450,18 @@ def build_demo(embed_mode):
356
  [state, temperature, top_p, max_output_tokens],
357
  [state, chatbot] + btn_list
358
  )
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  demo.load(
361
  load_demo,
 
113
  state.append_message(state.roles[0], text)
114
  state.append_message(state.roles[1], None)
115
  state.skip_next = False
116
+ print(text)
117
  return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
118
 
119
 
 
202
  "ip": request.client.host,
203
  }
204
  fout.write(json.dumps(data) + "\n")
205
+
206
+ def http_bot_modified(state, request: gr.Request):
207
+ logger.info(f"http_bot. ip: {request.client.host}")
208
+ start_tstamp = time.time()
209
+ if state.skip_next:
210
+ # This generate call is skipped due to invalid inputs
211
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
212
+ return
213
+
214
+ print(state.messages[-2][1])
215
+ state.messages[-2][1] = ('<|image|>Rate the quality of the image.',) + state.messages[-2][1][1:]
216
+ print(state.messages[-2][1])
217
+
218
+ if len(state.messages) == state.offset + 2:
219
+ # First round of conversation
220
+ template_name = "mplug_owl2"
221
+ new_state = conv_templates[template_name].copy()
222
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
223
+ new_state.append_message(new_state.roles[1], None)
224
+ state = new_state
225
+
226
+ # Construct prompt
227
+ prompt = state.get_prompt()
228
+
229
+ all_images = state.get_images(return_pil=True)
230
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
231
+ for image, hash in zip(all_images, all_image_hash):
232
+ t = datetime.datetime.now()
233
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
234
+ if not os.path.isfile(filename):
235
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
236
+ image.save(filename)
237
+
238
+ # Make requests
239
+ pload = {
240
+ "prompt": prompt,
241
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
242
+ }
243
+ logger.info(f"==== request ====\n{pload}")
244
+
245
+ pload['images'] = state.get_images()
246
+
247
+ state.messages[-1][-1] = "▌"
248
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
249
+
250
+ try:
251
+ # Stream output
252
+ # response = requests.post(worker_addr + "/worker_generate_stream",
253
+ # headers=headers, json=pload, stream=True, timeout=10)
254
+ # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
255
+ response = model.predict_stream_gate(pload)
256
+ for chunk in response:
257
+ if chunk:
258
+ data = json.loads(chunk.decode())
259
+ if data["error_code"] == 0:
260
+ output = data["text"][len(prompt):].strip()
261
+ state.messages[-1][-1] = output + "▌"
262
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
263
+ else:
264
+ output = data["text"] + f" (error_code: {data['error_code']})"
265
+ state.messages[-1][-1] = output
266
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
267
+ return
268
+ time.sleep(0.03)
269
+ except requests.exceptions.RequestException as e:
270
+ state.messages[-1][-1] = server_error_msg
271
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
272
+ return
273
+
274
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
275
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
276
+
277
+ finish_tstamp = time.time()
278
+ logger.info(f"{output}")
279
+
280
+ with open(get_conv_log_filename(), "a") as fout:
281
+ data = {
282
+ "tstamp": round(finish_tstamp, 4),
283
+ "type": "chat",
284
+ "start": round(start_tstamp, 4),
285
+ "finish": round(start_tstamp, 4),
286
+ "state": state.dict(),
287
+ "images": all_image_hash,
288
+ "ip": request.client.host,
289
+ }
290
+ fout.write(json.dumps(data) + "\n")
291
 
292
 
293
  title_markdown = ("""
 
295
 
296
  <h2 align="center">Q-Instruct: Improving Low-level Visual Abilities for Multi-modality Foundation Models</h2>
297
 
298
+ <h5 align="center"> If you like our project, please give us a star ✨ on [Github](https://github.com/Q-Future/Q-Instruct) for latest update. </h2>
299
 
300
  <div align="center">
301
  <div style="display:flex; gap: 0.25rem;" align="center">
 
305
  </div>
306
  </div>
307
 
308
+ ### Special Usage: *Rate!*
309
+ To get an image quality score, just upload a new image and click the **Rate!** button. This will redirect to a special method that return a quality score in [0,1].
310
+ Always make sure that there is some text in the textbox before you click the **Rate!** button.
311
+
312
  """)
313
 
314
 
315
  tos_markdown = ("""
316
+
317
  ### Terms of use
318
  By using this service, users are required to agree to the following terms:
319
  The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
 
336
  """
337
 
338
  def build_demo(embed_mode):
339
+ textbox = gr.Textbox(show_label=False, value="Rate the quality of the image.", placeholder="Enter text and press ENTER", container=False)
340
  with gr.Blocks(title="Q-Instruct-on-mPLUG-Owl-2", theme=gr.themes.Default(), css=block_css) as demo:
341
  state = gr.State()
342
 
 
363
  max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
364
 
365
  with gr.Column(scale=8):
366
+ chatbot = gr.Chatbot(elem_id="Chatbot", label="Q-Instruct-Chatbot", height=750)
367
  with gr.Row():
368
  with gr.Column(scale=8):
369
  textbox.render()
370
  with gr.Column(scale=1, min_width=50):
371
  submit_btn = gr.Button(value="Send", variant="primary")
372
+ with gr.Column(scale=1, min_width=50):
373
+ rate_btn = gr.Button(value="Rate!", variant="primary")
374
  with gr.Row(elem_id="buttons") as button_row:
375
  upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
376
  downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
 
450
  [state, temperature, top_p, max_output_tokens],
451
  [state, chatbot] + btn_list
452
  )
453
+
454
+ rate_btn.click(
455
+ add_text,
456
+ [state, textbox, imagebox, image_process_mode],
457
+ [state, chatbot, textbox, imagebox] + btn_list,
458
+ queue=False,
459
+ concurrency_limit=10,
460
+ ).then(
461
+ http_bot_modified,
462
+ [state],
463
+ [state, chatbot] + btn_list
464
+ )
465
 
466
  demo.load(
467
  load_demo,
model_worker.py CHANGED
@@ -45,7 +45,65 @@ class ModelWorker:
45
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
46
  model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
47
  self.is_multimodal = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @torch.inference_mode()
50
  def generate_stream(self, params):
51
  tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
@@ -115,6 +173,32 @@ class ModelWorker:
115
  if generated_text.endswith(stop_str):
116
  generated_text = generated_text[:-len(stop_str)]
117
  yield json.dumps({"text": generated_text, "error_code": 0}).encode()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def generate_stream_gate(self, params):
120
  try:
 
45
  self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
46
  model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
47
  self.is_multimodal = True
48
+
49
+ @torch.inference_mode()
50
+ def predict_stream(self, params):
51
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
52
+
53
+ prompt = params["prompt"] + "The quality of the image is"
54
+ ori_prompt = prompt
55
+ images = params.get("images", None)
56
+ num_image_tokens = 0
57
+ if images is not None and len(images) > 0 and self.is_multimodal:
58
+ if len(images) > 0:
59
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
60
+ raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
61
+
62
+ images = [load_image_from_base64(image) for image in images]
63
+ images = process_images(images, image_processor, model.config)
64
+
65
+ if type(images) is list:
66
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
67
+ else:
68
+ images = images.to(self.model.device, dtype=torch.float16)
69
+
70
+ replace_token = DEFAULT_IMAGE_TOKEN
71
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
72
 
73
+ num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
74
+ else:
75
+ images = None
76
+ image_args = {"images": images}
77
+ else:
78
+ images = None
79
+ image_args = {}
80
+
81
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
82
+
83
+ logits = model.forward(
84
+ input_ids=input_ids,
85
+ use_cache=True,
86
+ **image_args).logits[0,-1]
87
+
88
+ print(logits.shape)
89
+
90
+ softmax_logits = torch.softmax(logits[[1781,6588,6460]], 0)
91
+
92
+ print(tokenizer(["good", "average", "poor"]))
93
+ fake_streamer = []
94
+ for id_, word in enumerate(["good", "average", "poor"]):
95
+ stream_ = f"Probability of {word} quality: {softmax_logits[id_].item():.4f};\n"
96
+ fake_streamer.append(stream_)
97
+
98
+ quality_score = 0.5 * softmax_logits[1] + softmax_logits[0]
99
+ stream_ = f"Quality score: {quality_score:.4f} (range [0,1])."
100
+ fake_streamer.append(stream_)
101
+
102
+ generated_text = ori_prompt.replace("The quality of the image is", "")
103
+ for new_text in fake_streamer:
104
+ generated_text += new_text
105
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode()
106
+
107
  @torch.inference_mode()
108
  def generate_stream(self, params):
109
  tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
 
173
  if generated_text.endswith(stop_str):
174
  generated_text = generated_text[:-len(stop_str)]
175
  yield json.dumps({"text": generated_text, "error_code": 0}).encode()
176
+
177
+ def predict_stream_gate(self, params):
178
+ try:
179
+ for x in self.predict_stream(params):
180
+ yield x
181
+ except ValueError as e:
182
+ print("Caught ValueError:", e)
183
+ ret = {
184
+ "text": server_error_msg,
185
+ "error_code": 1,
186
+ }
187
+ yield json.dumps(ret).encode()
188
+ except torch.cuda.CudaError as e:
189
+ print("Caught torch.cuda.CudaError:", e)
190
+ ret = {
191
+ "text": server_error_msg,
192
+ "error_code": 1,
193
+ }
194
+ yield json.dumps(ret).encode()
195
+ except Exception as e:
196
+ print("Caught Unknown Error", e)
197
+ ret = {
198
+ "text": server_error_msg,
199
+ "error_code": 1,
200
+ }
201
+ yield json.dumps(ret).encode()
202
 
203
  def generate_stream_gate(self, params):
204
  try: