gulixin0922 commited on
Commit
aeb3796
·
verified ·
1 Parent(s): 930ec51

update backend api

Browse files
Files changed (1) hide show
  1. app.py +27 -184
app.py CHANGED
@@ -47,40 +47,6 @@ def write2file(path, content):
47
  fout.write(content)
48
 
49
 
50
- def sort_models(models):
51
- def custom_sort_key(model_name):
52
- # InternVL-Chat-V1-5 should be the first item
53
- if model_name == "InternVL2-Pro":
54
- return (2, model_name) # 2 indicates highest precedence
55
- elif model_name.startswith("InternVL2-8B"):
56
- return (1, model_name) # 0 indicates highest precedence
57
- else:
58
- return (0, model_name) # 0 indicates normal order
59
-
60
- models.sort(key=custom_sort_key, reverse=True)
61
- # try: # We have five InternVL-Chat-V1-5 models, randomly choose one to be the first
62
- # first_three = models[:4]
63
- # random.shuffle(first_three)
64
- # models[:4] = first_three
65
- # except:
66
- # pass
67
- return models
68
-
69
-
70
- def get_model_list():
71
- logger.info(f"Call `get_model_list`")
72
- ret = requests.post(args.controller_url + "/refresh_all_workers")
73
- logger.info(f"status_code from `get_model_list`: {ret.status_code}")
74
- assert ret.status_code == 200
75
- ret = requests.post(args.controller_url + "/list_models")
76
- logger.info(f"status_code from `list_models`: {ret.status_code}")
77
- models = ret.json()["models"]
78
- models = sort_models(models)
79
-
80
- logger.info(f"Models (from {args.controller_url}): {models}")
81
- return models
82
-
83
-
84
  get_window_url_params = """
85
  function() {
86
  const params = new URLSearchParams(window.location.search);
@@ -154,48 +120,6 @@ def find_bounding_boxes(state, response):
154
  return returned_image if len(matches) > 0 else None
155
 
156
 
157
- def query_image_generation(response, sd_worker_url, timeout=15):
158
- if not sd_worker_url:
159
- return None
160
- sd_worker_url = f"{sd_worker_url}/generate_image/"
161
- pattern = r"```drawing-instruction\n(.*?)\n```"
162
- match = re.search(pattern, response, re.DOTALL)
163
- if match:
164
- payload = {"caption": match.group(1)}
165
- print("drawing-instruction:", payload)
166
- response = requests.post(sd_worker_url, json=payload, timeout=timeout)
167
- response.raise_for_status() # 检查HTTP请求是否成功
168
- image = Image.open(BytesIO(response.content))
169
- return image
170
- else:
171
- return None
172
-
173
-
174
- def load_demo(url_params, request: gr.Request = None):
175
- if not request:
176
- logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
177
-
178
- dropdown_update = gr.Dropdown(visible=True)
179
- if "model" in url_params:
180
- model = url_params["model"]
181
- if model in models:
182
- dropdown_update = gr.Dropdown(value=model, visible=True)
183
-
184
- state = init_state()
185
- return state, dropdown_update
186
-
187
-
188
- def load_demo_refresh_model_list(request: gr.Request = None):
189
- if not request:
190
- logger.info(f"load_demo. ip: {request.client.host}")
191
- models = get_model_list()
192
- state = init_state()
193
- dropdown_update = gr.Dropdown(
194
- choices=models, value=models[0] if len(models) > 0 else ""
195
- )
196
- return state, dropdown_update
197
-
198
-
199
  def vote_last_response(state, liked, model_selector, request: gr.Request):
200
  conv_data = {
201
  "tstamp": round(time.time(), 4),
@@ -249,7 +173,7 @@ def flag_last_response(state, model_selector, request: gr.Request):
249
  def regenerate(state, image_process_mode, request: gr.Request):
250
  logger.info(f"regenerate. ip: {request.client.host}")
251
  # state.messages[-1][-1] = None
252
- state.update_message(Conversation.ASSISTANT, None, -1)
253
  prev_human_msg = state.messages[-2]
254
  if type(prev_human_msg[1]) in (tuple, list):
255
  prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
@@ -327,16 +251,11 @@ def http_bot(
327
  ) + (no_change_btn,) * 5
328
  return
329
 
330
- # Query worker address
331
- controller_url = args.controller_url
332
- ret = requests.post(
333
- controller_url + "/get_worker_address", json={"model": model_name}
334
- )
335
- worker_addr = ret.json()["address"]
336
- if worker_addr.startswith("http://0.0.0.0") and args.worker_ip:
337
- worker_addr = worker_addr.replace("0.0.0.0", args.worker_ip)
338
- logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
339
 
 
340
  # No available worker
341
  if worker_addr == "":
342
  # state.messages[-1][-1] = server_error_msg
@@ -359,20 +278,14 @@ def http_bot(
359
  # Make requests
360
  pload = {
361
  "model": model_name,
362
- "prompt": state.get_prompt(),
363
  "temperature": float(temperature),
364
  "top_p": float(top_p),
365
- "max_new_tokens": max_new_tokens,
366
- "max_input_tiles": max_input_tiles,
367
- # "bbox_threshold": bbox_threshold,
368
- # "mask_threshold": mask_threshold,
369
  "repetition_penalty": repetition_penalty,
370
- "images": f"List of {len(all_images)} images: {all_image_paths}",
371
  }
372
  logger.info(f"==== request ====\n{pload}")
373
- pload.pop("images")
374
- pload["prompt"] = state.get_prompt(inlude_image=True)
375
- state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
376
  yield (
377
  state,
378
  state.to_gradio_chatbot(),
@@ -381,50 +294,25 @@ def http_bot(
381
 
382
  try:
383
  # Stream output
384
- response = requests.post(
385
- worker_addr + "/worker_generate_stream",
386
- headers=headers,
387
- json=pload,
388
- stream=True,
389
- timeout=20,
390
- )
391
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
392
  if chunk:
393
- data = json.loads(chunk.decode())
394
- if data["error_code"] == 0:
395
- if "text" in data:
396
- output = data["text"].strip()
397
- output += state.streaming_placeholder
398
-
399
- image = None
400
- if "image" in data:
401
- image = load_image_from_base64(data["image"])
402
- _ = state.save_image(image)
403
-
404
- state.update_message(Conversation.ASSISTANT, output, image)
405
- yield (
406
- state,
407
- state.to_gradio_chatbot(),
408
- gr.MultimodalTextbox(interactive=False),
409
- ) + (disable_btn,) * 5
410
- else:
411
- output = (
412
- f"**{data['text']}**" + f" (error_code: {data['error_code']})"
413
- )
414
-
415
- state.update_message(Conversation.ASSISTANT, output, None)
416
- yield (
417
- state,
418
- state.to_gradio_chatbot(),
419
- gr.MultimodalTextbox(interactive=True),
420
- ) + (
421
- disable_btn,
422
- disable_btn,
423
- disable_btn,
424
- enable_btn,
425
- enable_btn,
426
- )
427
- return
428
  except requests.exceptions.RequestException as e:
429
  state.update_message(Conversation.ASSISTANT, server_error_msg, None)
430
  yield (
@@ -445,12 +333,6 @@ def http_bot(
445
  returned_image = find_bounding_boxes(state, ai_response)
446
  returned_image = [returned_image] if returned_image else []
447
  state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
448
- if "```drawing-instruction" in ai_response:
449
- returned_image = query_image_generation(
450
- ai_response, sd_worker_url=sd_worker_url
451
- )
452
- returned_image = [returned_image] if returned_image else []
453
- state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
454
 
455
  state.end_of_current_turn()
456
 
@@ -577,7 +459,7 @@ def build_demo(embed_mode):
577
  theme=gr.themes.Default(),
578
  css=block_css,
579
  ) as demo:
580
- models = get_model_list()
581
  state = gr.State()
582
 
583
  if not embed_mode:
@@ -797,27 +679,6 @@ def build_demo(embed_mode):
797
  [state, chatbot, textbox] + btn_list,
798
  )
799
 
800
- # NOTE: The following code will be not triggered when deployed on HF space.
801
- # It's very strange. I don't know why.
802
- """
803
- if args.model_list_mode == "once":
804
- demo.load(
805
- load_demo,
806
- [url_params],
807
- [state, model_selector],
808
- js=js,
809
- )
810
- elif args.model_list_mode == "reload":
811
- demo.load(
812
- load_demo_refresh_model_list,
813
- None,
814
- [state, model_selector],
815
- js=js,
816
- )
817
- else:
818
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
819
- """
820
-
821
  return demo
822
 
823
 
@@ -825,31 +686,13 @@ if __name__ == "__main__":
825
  parser = argparse.ArgumentParser()
826
  parser.add_argument("--host", type=str, default="0.0.0.0")
827
  parser.add_argument("--port", type=int, default=7860)
828
- parser.add_argument("--controller-url", type=str, default=None)
829
- parser.add_argument("--worker-ip", type=str, default=None)
830
  parser.add_argument("--concurrency-count", type=int, default=10)
831
- parser.add_argument(
832
- "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
833
- )
834
- parser.add_argument("--sd-worker-url", type=str, default=None)
835
  parser.add_argument("--share", action="store_true")
836
  parser.add_argument("--moderate", action="store_true")
837
  parser.add_argument("--embed", action="store_true")
838
  args = parser.parse_args()
839
  logger.info(f"args: {args}")
840
- if not args.controller_url:
841
- args.controller_url = os.environ.get("CONTROLLER_URL", None)
842
-
843
- if not args.controller_url:
844
- raise ValueError("controller-url is required.")
845
-
846
- if not args.worker_ip:
847
- args.worker_ip = os.environ.get("WORKER_IP", None)
848
 
849
- model_lists = ["OpenGVLab/InternVL-Chat-V1-5", "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-2B",
850
- "OpenGVLab/InternVL2-4B", "OpenGVLab/InternVL2-8B", "OpenGVLab/InternVL2-26B",
851
- "OpenGVLab/InternVL2-40B", "OpenGVLab/InternVL2-Llama3-76B"]
852
- sd_worker_url = args.sd_worker_url
853
  logger.info(args)
854
  demo = build_demo(args.embed)
855
  demo.queue(api_open=False).launch(
 
47
  fout.write(content)
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  get_window_url_params = """
51
  function() {
52
  const params = new URLSearchParams(window.location.search);
 
120
  return returned_image if len(matches) > 0 else None
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def vote_last_response(state, liked, model_selector, request: gr.Request):
124
  conv_data = {
125
  "tstamp": round(time.time(), 4),
 
173
  def regenerate(state, image_process_mode, request: gr.Request):
174
  logger.info(f"regenerate. ip: {request.client.host}")
175
  # state.messages[-1][-1] = None
176
+ state.update_message(Conversation.ASSISTANT, content='', image=None, idx=-1)
177
  prev_human_msg = state.messages[-2]
178
  if type(prev_human_msg[1]) in (tuple, list):
179
  prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
 
251
  ) + (no_change_btn,) * 5
252
  return
253
 
254
+ worker_addr = os.environ.get("WORKER_ADDR", "")
255
+ api_token = os.environ.get("API_TOKEN", "")
256
+ headers = {"Authorization": f"{api_token}", "Content-Type": "application/json"}
 
 
 
 
 
 
257
 
258
+ state.append_message(Conversation.ASSISTANT, state.streaming_placeholder)
259
  # No available worker
260
  if worker_addr == "":
261
  # state.messages[-1][-1] = server_error_msg
 
278
  # Make requests
279
  pload = {
280
  "model": model_name,
281
+ "messages": state.get_prompt_v2(inlude_image=True, max_dynamic_patch=max_input_tiles),
282
  "temperature": float(temperature),
283
  "top_p": float(top_p),
284
+ "max_tokens": max_new_tokens,
 
 
 
285
  "repetition_penalty": repetition_penalty,
286
+ "stream": True
287
  }
288
  logger.info(f"==== request ====\n{pload}")
 
 
 
289
  yield (
290
  state,
291
  state.to_gradio_chatbot(),
 
294
 
295
  try:
296
  # Stream output
297
+ response = requests.post(worker_addr, json=pload, headers=headers, stream=True)
298
+ finnal_output = ''
299
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\n"):
 
 
 
 
 
300
  if chunk:
301
+ chunk = chunk.decode()
302
+ if chunk == 'data: [DONE]':
303
+ break
304
+ if chunk.startswith("data:"):
305
+ chunk = chunk[5:]
306
+ chunk = json.loads(chunk)
307
+ output = chunk['choices'][0]['delta']['content']
308
+ finnal_output += output
309
+
310
+ state.update_message(Conversation.ASSISTANT, finnal_output + state.streaming_placeholder, None)
311
+ yield (
312
+ state,
313
+ state.to_gradio_chatbot(),
314
+ gr.MultimodalTextbox(interactive=False),
315
+ ) + (disable_btn,) * 5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  except requests.exceptions.RequestException as e:
317
  state.update_message(Conversation.ASSISTANT, server_error_msg, None)
318
  yield (
 
333
  returned_image = find_bounding_boxes(state, ai_response)
334
  returned_image = [returned_image] if returned_image else []
335
  state.update_message(Conversation.ASSISTANT, ai_response, returned_image)
 
 
 
 
 
 
336
 
337
  state.end_of_current_turn()
338
 
 
459
  theme=gr.themes.Default(),
460
  css=block_css,
461
  ) as demo:
462
+ models = ['InternVL2-Pro']
463
  state = gr.State()
464
 
465
  if not embed_mode:
 
679
  [state, chatbot, textbox] + btn_list,
680
  )
681
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  return demo
683
 
684
 
 
686
  parser = argparse.ArgumentParser()
687
  parser.add_argument("--host", type=str, default="0.0.0.0")
688
  parser.add_argument("--port", type=int, default=7860)
 
 
689
  parser.add_argument("--concurrency-count", type=int, default=10)
 
 
 
 
690
  parser.add_argument("--share", action="store_true")
691
  parser.add_argument("--moderate", action="store_true")
692
  parser.add_argument("--embed", action="store_true")
693
  args = parser.parse_args()
694
  logger.info(f"args: {args}")
 
 
 
 
 
 
 
 
695
 
 
 
 
 
696
  logger.info(args)
697
  demo = build_demo(args.embed)
698
  demo.queue(api_open=False).launch(