Spaces:
Sleeping
Sleeping
liuhaotian
committed on
Commit
•
14f034e
1
Parent(s):
3e4d21c
Init
Browse files- app.py +93 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import time
|
5 |
+
import subprocess
|
6 |
+
|
7 |
+
import llava.serve.gradio_web_server as gws
|
8 |
+
|
9 |
+
|
10 |
+
def start_controller():
    """Launch the LLaVA controller as a child process.

    The controller module (``llava.serve.controller``) is started with
    ``-m`` and told to listen on ``0.0.0.0:10000``.

    Returns:
        The ``subprocess.Popen`` handle of the controller process; the
        caller is responsible for terminating it.
    """
    print("Starting the controller")
    # Use the interpreter that is running this script so the child sees the
    # same environment/site-packages; fall back to PATH lookup if unknown.
    python_bin = sys.executable or "python"
    controller_command = [
        python_bin,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    return subprocess.Popen(controller_command)
|
22 |
+
|
23 |
+
|
24 |
+
def start_worker(model_path: str, bits=16):
    """Launch the LLaVA model worker as a child process.

    The worker module (``llava.serve.model_worker``) is started with ``-m``,
    bound to ``0.0.0.0`` and pointed at the controller URL
    ``http://localhost:10000``. The model name passed to the worker is the
    last path component of *model_path*, suffixed with ``-{bits}bit`` when
    *bits* is not 16.

    Args:
        model_path: Value forwarded to ``--model-path``.
        bits: Load precision; must be 4, 8 or 16. Any value other than 16
            also appends the matching ``--load-{bits}bit`` flag.

    Returns:
        The ``subprocess.Popen`` handle of the worker process; the caller
        is responsible for terminating it.

    Raises:
        ValueError: If *bits* is not one of 4, 8 or 16.
    """
    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O`` and an invalid value would yield a bogus CLI flag.
    if bits not in (4, 8, 16):
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")

    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    quantized = bits != 16
    if quantized:
        model_name += f"-{bits}bit"

    # Use the interpreter that is running this script so the child sees the
    # same environment/site-packages; fall back to PATH lookup if unknown.
    python_bin = sys.executable or "python"
    worker_command = [
        python_bin,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
    ]
    if quantized:
        worker_command.append(f"--load-{bits}bit")
    return subprocess.Popen(worker_command)
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
    # Entry point: parse CLI flags, start the controller and model worker as
    # child processes, serve the Gradio demo, and always clean up the
    # children before exiting with a 0/1 status.

    # Local import: only the entry point needs it.
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # The parsed args and an empty model list are stored on the gradio web
    # server module (presumably read by build_demo - confirm against llava).
    gws.args = parser.parse_args()
    gws.models = []

    print(f"args: {gws.args}")

    model_path = "liuhaotian/llava-v1.6-mistral-7b"
    bits = int(os.getenv("bits", 4))
    # The environment variable wins; otherwise honour --concurrency-count
    # (whose default is 5) instead of ignoring the flag.
    concurrency_count = int(os.getenv("concurrency_count", gws.args.concurrency_count))

    controller_proc = None
    worker_proc = None
    exit_status = 0
    try:
        # Started inside the try so that a failure launching the worker
        # still tears down an already-running controller.
        controller_proc = start_controller()
        worker_proc = start_worker(model_path, bits=bits)

        # Wait for worker and controller to start
        time.sleep(10)

        # --embed defaults to False, matching the previous hard-coded value.
        demo = gws.build_demo(embed_mode=gws.args.embed, cur_dir='./')
        demo.queue(
            concurrency_count=concurrency_count,
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )

    except Exception as e:
        print(e)
        # Keep the exception type and stack; print(e) alone drops both.
        traceback.print_exc()
        exit_status = 1
    finally:
        for proc in (worker_proc, controller_proc):
            if proc is not None:
                proc.kill()
                # Reap the child so it does not linger as a zombie.
                proc.wait()

    sys.exit(exit_status)
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
llava-torch==1.2.1.post1
|
2 |
+
protobuf==4.23.3
|