liuhaotian committed
Commit 2043a67
1 Parent(s): 255cd6e

Load 13B model with 8-bit/4-bit quantization to support more hardware

Files changed (1)
  1. app.py +21 -3
app.py CHANGED
@@ -325,6 +325,15 @@ title_markdown = """
 [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
 
 ONLY WORKS WITH GPU!
+
+You can load the model with 8-bit or 4-bit quantization to make it fit on smaller hardware. Set the environment variable `bits` to control the quantization.
+
+Recommended configurations:
+| Hardware         | Bits        |
+|------------------|-------------|
+| A10G-Large (24G) | 8 (default) |
+| T4-Medium (15G)  | 4           |
+| A100-Large (40G) | 16          |
 """
 
 tos_markdown = """
@@ -522,8 +531,12 @@ def start_controller():
     return subprocess.Popen(controller_command)
 
 
-def start_worker(model_path: str):
+def start_worker(model_path: str, bits=16):
     logger.info(f"Starting the model worker for the model {model_path}")
+    model_name = model_path.strip('/').split('/')[-1]
+    assert bits in [4, 8, 16], "The model can only be loaded in 16-bit, 8-bit, or 4-bit."
+    if bits != 16:
+        model_name += f'-{bits}bit'
     worker_command = [
         "python",
         "-m",
@@ -534,7 +547,11 @@ def start_worker(model_path: str):
         "http://localhost:10000",
         "--model-path",
         model_path,
+        "--model-name",
+        model_name,
     ]
+    if bits != 16:
+        worker_command += [f'--load-{bits}bit']
     return subprocess.Popen(worker_command)
 
 
@@ -582,12 +599,13 @@ if __name__ == "__main__":
     args = get_args()
     logger.info(f"args: {args}")
 
-    model_path = "liuhaotian/llava-v1.5-7b"
+    model_path = "liuhaotian/llava-v1.5-13b"
+    bits = int(os.getenv("bits", 8))
 
     preload_models(model_path)
 
     controller_proc = start_controller()
-    worker_proc = start_worker(model_path)
+    worker_proc = start_worker(model_path, bits=bits)
 
     # Wait for worker and controller to start
     time.sleep(10)
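
For context, a minimal launch sketch (an illustration, not part of the commit): app.py reads the `bits` environment variable via `os.getenv("bits", 8)` and passes it to `start_worker()`, which appends `--load-8bit` or `--load-4bit` to the worker command. Assuming app.py sits in the current directory, the quantization level could be selected like this:

```python
# Hypothetical launcher sketch: choose the quantization level via the `bits`
# environment variable before starting the Gradio demo (app.py).
import os
import subprocess

env = os.environ.copy()
env["bits"] = "4"  # e.g. 4-bit for a T4 (15G); "8" (default) or "16" for larger GPUs

# app.py reads `bits`, then launches the controller and the quantized model worker.
subprocess.run(["python", "app.py"], env=env, check=True)
```

With `bits=16`, the worker is launched without any `--load-*bit` flag, so the model runs in full 16-bit precision.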