Sharaf Zaman committed on
Commit
d1b18a4
1 Parent(s): f8409fb

Add file to launch the space (based on LLaVA-1.6)

Files changed (2)
  1. app.py +102 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,102 @@
+# Based on liuhaotian/LLaVA-1.6
+
+import sys
+import os
+import argparse
+import time
+import subprocess
+
+import llava.serve.gradio_web_server as gws
+
+# Execute the pip install command with additional options
+subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
+
+
+def start_controller():
+    print("Starting the controller")
+    controller_command = [
+        sys.executable,
+        "-m",
+        "llava.serve.controller",
+        "--host",
+        "0.0.0.0",
+        "--port",
+        "10000",
+    ]
+    print(controller_command)
+    return subprocess.Popen(controller_command)
+
+
+def start_worker(model_path: str, bits=16):
+    print(f"Starting the model worker for the model {model_path}")
+    model_name = model_path.strip("/").split("/")[-1]
+    assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
+    if bits != 16:
+        model_name += f"-{bits}bit"
+    worker_command = [
+        sys.executable,
+        "-m",
+        "llava.serve.model_worker",
+        "--host",
+        "0.0.0.0",
+        "--controller",
+        "http://localhost:10000",
+        "--model-path",
+        model_path,
+        "--model-name",
+        model_name,
+        "--use-flash-attn",
+    ]
+    if bits != 16:
+        worker_command += [f"--load-{bits}bit"]
+    print(worker_command)
+    return subprocess.Popen(worker_command)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int)
+    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
+    parser.add_argument("--concurrency-count", type=int, default=5)
+    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
+    parser.add_argument("--share", action="store_true")
+    parser.add_argument("--moderate", action="store_true")
+    parser.add_argument("--embed", action="store_true")
+    gws.args = parser.parse_args()
+    gws.models = []
+
+    gws.title_markdown += """ AstroLLaVA """
+
+    print(f"args: {gws.args}")
+
+    model_path = os.getenv("model", "universeTBD/astroLLaVA")
+    bits = int(os.getenv("bits", 4))
+    concurrency_count = int(os.getenv("concurrency_count", 5))
+
+    controller_proc = start_controller()
+    worker_proc = start_worker(model_path, bits=bits)
+
+    # Wait for worker and controller to start
+    time.sleep(10)
+
+    exit_status = 0
+    try:
+        demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
+        demo.queue(
+            status_update_rate=10,
+            api_open=False
+        ).launch(
+            server_name=gws.args.host,
+            server_port=gws.args.port,
+            share=gws.args.share
+        )
+
+    except Exception as e:
+        print(e)
+        exit_status = 1
+    finally:
+        worker_proc.kill()
+        controller_proc.kill()
+
+        sys.exit(exit_status)
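
Note: the launcher above starts the LLaVA controller and a model worker as subprocesses, then waits a fixed 10 seconds before building the Gradio demo. Below is a minimal, hedged sketch of a readiness check that polls the controller instead of sleeping; it assumes the controller exposes the POST /refresh_all_workers and /list_models endpoints that llava.serve.gradio_web_server itself queries (an assumption about the installed llava-torch version, not something added by this commit).

import time
import requests

def wait_for_worker(controller_url: str = "http://localhost:10000",
                    timeout: float = 120.0, interval: float = 2.0) -> bool:
    """Poll the controller until at least one model worker has registered.

    Hedged sketch: assumes the controller serves POST /refresh_all_workers
    and POST /list_models, as used by llava.serve.gradio_web_server.
    Returns False if no worker shows up before the timeout.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            requests.post(f"{controller_url}/refresh_all_workers", timeout=5)
            resp = requests.post(f"{controller_url}/list_models", timeout=5)
            if resp.json().get("models"):
                return True
        except requests.RequestException:
            pass  # controller not accepting connections yet
        time.sleep(interval)
    return False

If adopted, the time.sleep(10) call could be replaced by a wait_for_worker() call, with an error exit when it returns False.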
requirements.txt ADDED
@@ -0,0 +1 @@
+llava-torch
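
Note: requirements.txt pulls in llava-torch, the PyPI distribution that provides the llava.serve modules imported by app.py; flash-attn is installed separately at runtime by app.py, presumably because its build requires --no-build-isolation. The Space is configured through the model, bits, and concurrency_count environment variables read in app.py. A hedged sketch of running the launcher locally with overridden values (the values shown are illustrative only):

import os
import subprocess
import sys

# Hedged local-run sketch: override the environment variables that app.py
# reads via os.getenv ("model", "bits", "concurrency_count") and start the
# launcher on a fixed port.
env = dict(os.environ, model="universeTBD/astroLLaVA", bits="16", concurrency_count="3")
subprocess.run([sys.executable, "app.py", "--port", "7860"], env=env, check=True)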