Spaces:
Sleeping
Sleeping
Nick Vandal
commited on
Commit
•
6c420e0
1
Parent(s):
0c371b7
added mupliple models and revisions
Browse files
LLaVA
CHANGED
@@ -1 +1 @@
|
|
1 |
-
Subproject commit
|
|
|
1 |
+
Subproject commit 3c2f6ba15ed0477f4149fd582d2b640e19da2a57
|
app.py
CHANGED
@@ -25,7 +25,7 @@ def start_controller():
|
|
25 |
return subprocess.Popen(controller_command)
|
26 |
|
27 |
|
28 |
-
def start_worker(model_path: str, bits=16):
|
29 |
print(f"Starting the model worker for the model {model_path}")
|
30 |
model_name = model_path.strip("/").split("/")[-1]
|
31 |
assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
|
@@ -37,6 +37,10 @@ def start_worker(model_path: str, bits=16):
|
|
37 |
"llava.serve.model_worker",
|
38 |
"--host",
|
39 |
"0.0.0.0",
|
|
|
|
|
|
|
|
|
40 |
"--controller",
|
41 |
"http://localhost:10000",
|
42 |
"--model-path",
|
@@ -44,6 +48,8 @@ def start_worker(model_path: str, bits=16):
|
|
44 |
"--model-name",
|
45 |
model_name,
|
46 |
"--use-flash-attn",
|
|
|
|
|
47 |
]
|
48 |
if bits != 16:
|
49 |
worker_command += [f"--load-{bits}bit"]
|
@@ -77,12 +83,21 @@ Set the environment variable `model` to change the model:
|
|
77 |
|
78 |
print(f"args: {gws.args}")
|
79 |
|
80 |
-
|
|
|
81 |
bits = int(os.getenv("bits", 4))
|
82 |
concurrency_count = int(os.getenv("concurrency_count", 5))
|
83 |
|
84 |
controller_proc = start_controller()
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
# Wait for worker and controller to start
|
88 |
time.sleep(10)
|
@@ -103,7 +118,8 @@ Set the environment variable `model` to change the model:
|
|
103 |
print(e)
|
104 |
exit_status = 1
|
105 |
finally:
|
106 |
-
worker_proc
|
|
|
107 |
controller_proc.kill()
|
108 |
|
109 |
sys.exit(exit_status)
|
|
|
25 |
return subprocess.Popen(controller_command)
|
26 |
|
27 |
|
28 |
+
def start_worker(model_path: str, bits=16, revision='main', port=21002):
|
29 |
print(f"Starting the model worker for the model {model_path}")
|
30 |
model_name = model_path.strip("/").split("/")[-1]
|
31 |
assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
|
|
|
37 |
"llava.serve.model_worker",
|
38 |
"--host",
|
39 |
"0.0.0.0",
|
40 |
+
"--port",
|
41 |
+
port,
|
42 |
+
"--worker-address",
|
43 |
+
f"http://127.0.0.1:{port}",
|
44 |
"--controller",
|
45 |
"http://localhost:10000",
|
46 |
"--model-path",
|
|
|
48 |
"--model-name",
|
49 |
model_name,
|
50 |
"--use-flash-attn",
|
51 |
+
"--revision",
|
52 |
+
revision
|
53 |
]
|
54 |
if bits != 16:
|
55 |
worker_command += [f"--load-{bits}bit"]
|
|
|
83 |
|
84 |
print(f"args: {gws.args}")
|
85 |
|
86 |
+
model_paths = os.getenv("model", "nvandal/LLaVA-Med-v1.5-7b")
|
87 |
+
revisions = os.getenv("revision", "main")
|
88 |
bits = int(os.getenv("bits", 4))
|
89 |
concurrency_count = int(os.getenv("concurrency_count", 5))
|
90 |
|
91 |
controller_proc = start_controller()
|
92 |
+
start_worker_port = 21002
|
93 |
+
|
94 |
+
model_paths = model_paths.split(';')
|
95 |
+
revisions = revisions.split(';')
|
96 |
+
assert(len(model_paths)==len(revisions))
|
97 |
+
worker_proc = [None]*len(model_paths)
|
98 |
+
for i, (model_path, revision) in enumerate(zip(model_paths,revisions)):
|
99 |
+
print(model_path, revision)
|
100 |
+
worker_proc[i] = start_worker(model_path, bits=bits, revision=revision, port=str(start_worker_port+i))
|
101 |
|
102 |
# Wait for worker and controller to start
|
103 |
time.sleep(10)
|
|
|
118 |
print(e)
|
119 |
exit_status = 1
|
120 |
finally:
|
121 |
+
for w in worker_proc:
|
122 |
+
w.kill()
|
123 |
controller_proc.kill()
|
124 |
|
125 |
sys.exit(exit_status)
|