badayvedat commited on
Commit
5df3ede
ยท
1 Parent(s): c6dfdac

Remove model preloading

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. app.py +37 -31
README.md CHANGED
@@ -4,5 +4,6 @@ emoji: ๐Ÿ”ฅ
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
 
7
  app_port: 7860
8
  ---
 
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 3.36.1
8
  app_port: 7860
9
  ---
app.py CHANGED
@@ -76,10 +76,26 @@ def load_demo_refresh_model_list(request: gr.Request):
76
  logger.info(f"load_demo. ip: {request.client.host}")
77
  models = get_model_list()
78
  state = default_conversation.copy()
79
- dropdown_update = gr.Dropdown.update(
80
- choices=models, value=models[0] if len(models) > 0 else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
- return state, dropdown_update
 
83
 
84
 
85
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
@@ -375,8 +391,8 @@ def build_demo(embed_mode):
375
  with gr.Row(elem_id="model_selector_row"):
376
  model_selector = gr.Dropdown(
377
  choices=models,
378
- value=models[0] if len(models) > 0 else "",
379
- interactive=True,
380
  show_label=False,
381
  container=False,
382
  )
@@ -438,7 +454,9 @@ def build_demo(embed_mode):
438
  with gr.Column(scale=8):
439
  textbox.render()
440
  with gr.Column(scale=1, min_width=50):
441
- submit_btn = gr.Button(value="Send", variant="primary")
 
 
442
  with gr.Row(elem_id="buttons") as button_row:
443
  upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=False)
444
  downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
@@ -509,7 +527,9 @@ def build_demo(embed_mode):
509
  _js=get_window_url_params,
510
  )
511
  elif args.model_list_mode == "reload":
512
- demo.load(load_demo_refresh_model_list, None, [state, model_selector])
 
 
513
  else:
514
  raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
515
 
@@ -532,10 +552,10 @@ def start_controller():
532
 
533
  def start_worker(model_path: str, bits=16):
534
  logger.info(f"Starting the model worker for the model {model_path}")
535
- model_name = model_path.strip('/').split('/')[-1]
536
  assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
537
  if bits != 16:
538
- model_name += f'-{bits}bit'
539
  worker_command = [
540
  "python",
541
  "-m",
@@ -550,25 +570,10 @@ def start_worker(model_path: str, bits=16):
550
  model_name,
551
  ]
552
  if bits != 16:
553
- worker_command += [f'--load-{bits}bit']
554
  return subprocess.Popen(worker_command)
555
 
556
 
557
- def preload_models(model_path: str):
558
- import torch
559
-
560
- from llava.model import LlavaLlamaForCausalLM
561
-
562
- model = LlavaLlamaForCausalLM.from_pretrained(
563
- model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16
564
- )
565
- vision_tower = model.get_vision_tower()
566
- vision_tower.load_model()
567
-
568
- del vision_tower
569
- del model
570
-
571
-
572
  def get_args():
573
  parser = argparse.ArgumentParser()
574
  parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -601,19 +606,20 @@ if __name__ == "__main__":
601
  model_path = "liuhaotian/llava-v1.5-13b"
602
  bits = int(os.getenv("bits", 8))
603
 
604
- preload_models(model_path)
605
-
606
  controller_proc = start_controller()
607
  worker_proc = start_worker(model_path, bits=bits)
608
 
609
  # Wait for worker and controller to start
610
  time.sleep(10)
611
 
 
612
  try:
613
  start_demo(args)
614
  except Exception as e:
615
- worker_proc.terminate()
616
- controller_proc.terminate()
617
-
618
  print(e)
619
- sys.exit(1)
 
 
 
 
 
 
76
  logger.info(f"load_demo. ip: {request.client.host}")
77
  models = get_model_list()
78
  state = default_conversation.copy()
79
+
80
+ models_downloaded = True if models else False
81
+
82
+ model_dropdown_kwargs = {
83
+ "choices": [],
84
+ "value": "Downloading the models...",
85
+ "interactive": models_downloaded,
86
+ }
87
+
88
+ if models_downloaded:
89
+ model_dropdown_kwargs["choices"] = models
90
+ model_dropdown_kwargs["value"] = models[0]
91
+
92
+ models_dropdown_update = gr.Dropdown.update(**model_dropdown_kwargs)
93
+
94
+ send_button_update = gr.Button.update(
95
+ interactive=models_downloaded,
96
  )
97
+
98
+ return state, models_dropdown_update, send_button_update
99
 
100
 
101
  def vote_last_response(state, vote_type, model_selector, request: gr.Request):
 
391
  with gr.Row(elem_id="model_selector_row"):
392
  model_selector = gr.Dropdown(
393
  choices=models,
394
+ value=models[0] if models else "Downloading the models...",
395
+ interactive=True if models else False,
396
  show_label=False,
397
  container=False,
398
  )
 
454
  with gr.Column(scale=8):
455
  textbox.render()
456
  with gr.Column(scale=1, min_width=50):
457
+ submit_btn = gr.Button(
458
+ value="Send", variant="primary", interactive=False
459
+ )
460
  with gr.Row(elem_id="buttons") as button_row:
461
  upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=False)
462
  downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
 
527
  _js=get_window_url_params,
528
  )
529
  elif args.model_list_mode == "reload":
530
+ demo.load(
531
+ load_demo_refresh_model_list, None, [state, model_selector, submit_btn]
532
+ )
533
  else:
534
  raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
535
 
 
552
 
553
  def start_worker(model_path: str, bits=16):
554
  logger.info(f"Starting the model worker for the model {model_path}")
555
+ model_name = model_path.strip("/").split("/")[-1]
556
  assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
557
  if bits != 16:
558
+ model_name += f"-{bits}bit"
559
  worker_command = [
560
  "python",
561
  "-m",
 
570
  model_name,
571
  ]
572
  if bits != 16:
573
+ worker_command += [f"--load-{bits}bit"]
574
  return subprocess.Popen(worker_command)
575
 
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  def get_args():
578
  parser = argparse.ArgumentParser()
579
  parser.add_argument("--host", type=str, default="0.0.0.0")
 
606
  model_path = "liuhaotian/llava-v1.5-13b"
607
  bits = int(os.getenv("bits", 8))
608
 
 
 
609
  controller_proc = start_controller()
610
  worker_proc = start_worker(model_path, bits=bits)
611
 
612
  # Wait for worker and controller to start
613
  time.sleep(10)
614
 
615
+ exit_status = 0
616
  try:
617
  start_demo(args)
618
  except Exception as e:
 
 
 
619
  print(e)
620
+ exit_status = 1
621
+ finally:
622
+ worker_proc.kill()
623
+ controller_proc.kill()
624
+
625
+ sys.exit(exit_status)