Merge pull request #2 from TRI-ML/master
Browse files- interactive_demo.py +16 -39
- serve/__init__.py +7 -14
- serve/gradio_web_server.py +0 -27
interactive_demo.py
CHANGED
@@ -47,20 +47,12 @@ def heart_beat_worker(controller):
|
|
47 |
|
48 |
|
49 |
class ModelWorker:
|
50 |
-
def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm,
|
51 |
self.controller_addr = controller_addr
|
52 |
self.worker_addr = worker_addr
|
53 |
self.worker_id = worker_id
|
54 |
self.model_name = model_name
|
55 |
-
|
56 |
-
# logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
|
57 |
self.vlm = vlm
|
58 |
-
self.tokenizer, self.model, self.image_processor, self.context_len = (
|
59 |
-
vlm.tokenizer,
|
60 |
-
vlm.model,
|
61 |
-
vlm.image_processor,
|
62 |
-
vlm.max_length,
|
63 |
-
)
|
64 |
|
65 |
if not no_register:
|
66 |
self.register_to_controller()
|
@@ -68,18 +60,12 @@ class ModelWorker:
|
|
68 |
self.heart_beat_thread.start()
|
69 |
|
70 |
def register_to_controller(self):
|
71 |
-
# logger.info("Register to controller")
|
72 |
-
|
73 |
url = self.controller_addr + "/register_worker"
|
74 |
data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
|
75 |
r = requests.post(url, json=data)
|
76 |
assert r.status_code == 200
|
77 |
|
78 |
def send_heart_beat(self):
|
79 |
-
# logger.info(f"Send heart beat. Models: {[self.model_name]}. "
|
80 |
-
# f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
|
81 |
-
# f"global_counter: {global_counter}")
|
82 |
-
|
83 |
url = self.controller_addr + "/receive_heart_beat"
|
84 |
|
85 |
while True:
|
@@ -91,7 +77,6 @@ class ModelWorker:
|
|
91 |
break
|
92 |
except requests.exceptions.RequestException:
|
93 |
pass
|
94 |
-
# logger.error(f"heart beat error: {e}")
|
95 |
time.sleep(5)
|
96 |
|
97 |
if not exist:
|
@@ -145,12 +130,12 @@ class ModelWorker:
|
|
145 |
else:
|
146 |
question_prompt = [prompt_fn()]
|
147 |
|
148 |
-
if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
|
149 |
# This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
|
150 |
-
pixel_values = self.image_processor(images[0].convert("RGB"))
|
151 |
else:
|
152 |
# Assume `image_transform` is a HF ImageProcessor...
|
153 |
-
pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
|
154 |
|
155 |
if type(pixel_values) is dict:
|
156 |
for k in pixel_values.keys():
|
@@ -227,31 +212,29 @@ overwatch = initialize_overwatch(__name__)
|
|
227 |
class DemoConfig:
|
228 |
# fmt: off
|
229 |
|
230 |
-
# === Model Parameters =>>
|
231 |
-
model_family: str = "
|
232 |
-
model_id: str = "
|
233 |
-
model_dir:
|
234 |
-
"resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
|
235 |
-
)
|
236 |
|
237 |
# === Model Parameters =>> Official LLaVa ===
|
238 |
# model_family: str = "llava-v15"
|
239 |
# model_id: str = "llava-v1.5-13b"
|
240 |
# model_dir: Path = "liuhaotian/llava-v1.5-13b"
|
241 |
|
|
|
|
|
|
|
|
|
|
|
242 |
# Model Worker Parameters
|
243 |
host: str = "0.0.0.0"
|
244 |
port: int = 40000
|
245 |
controller_address: str = "http://localhost:10000"
|
246 |
-
model_base: str = "llava-v15"
|
247 |
limit_model_concurrency: int = 5
|
248 |
stream_interval: int = 1
|
249 |
no_register: bool = False
|
250 |
|
251 |
-
# Inference Parameters
|
252 |
-
device_batch_size: int = 1 # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs!
|
253 |
-
num_workers: int = 2 # Number of Dataloader Workers (on each process)
|
254 |
-
|
255 |
# HF Hub Credentials (for LLaMa-2)
|
256 |
hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
|
257 |
|
@@ -259,14 +242,8 @@ class DemoConfig:
|
|
259 |
seed: int = 21 # Random Seed (for reproducibility)
|
260 |
|
261 |
def __post_init__(self) -> None:
|
262 |
-
|
263 |
-
|
264 |
-
self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
|
265 |
-
elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
|
266 |
-
self.model_name = MODEL_ID_TO_NAME[self.model_id]
|
267 |
-
self.run_dir = self.model_dir
|
268 |
-
else:
|
269 |
-
raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")
|
270 |
self.worker_address = f"http://localhost:{self.port}"
|
271 |
|
272 |
# fmt: on
|
@@ -286,7 +263,7 @@ def interactive_demo(cfg: DemoConfig):
|
|
286 |
global limit_model_concurrency
|
287 |
limit_model_concurrency = cfg.limit_model_concurrency
|
288 |
worker = ModelWorker(
|
289 |
-
cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.
|
290 |
)
|
291 |
uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
|
292 |
|
|
|
47 |
|
48 |
|
49 |
class ModelWorker:
|
50 |
+
def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_name):
|
51 |
self.controller_addr = controller_addr
|
52 |
self.worker_addr = worker_addr
|
53 |
self.worker_id = worker_id
|
54 |
self.model_name = model_name
|
|
|
|
|
55 |
self.vlm = vlm
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
if not no_register:
|
58 |
self.register_to_controller()
|
|
|
60 |
self.heart_beat_thread.start()
|
61 |
|
62 |
def register_to_controller(self):
|
|
|
|
|
63 |
url = self.controller_addr + "/register_worker"
|
64 |
data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
|
65 |
r = requests.post(url, json=data)
|
66 |
assert r.status_code == 200
|
67 |
|
68 |
def send_heart_beat(self):
|
|
|
|
|
|
|
|
|
69 |
url = self.controller_addr + "/receive_heart_beat"
|
70 |
|
71 |
while True:
|
|
|
77 |
break
|
78 |
except requests.exceptions.RequestException:
|
79 |
pass
|
|
|
80 |
time.sleep(5)
|
81 |
|
82 |
if not exist:
|
|
|
130 |
else:
|
131 |
question_prompt = [prompt_fn()]
|
132 |
|
133 |
+
if isinstance(self.vlm.image_processor, Compose) or hasattr(self.vlm.image_processor, "is_prismatic"):
|
134 |
# This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
|
135 |
+
pixel_values = self.vlm.image_processor(images[0].convert("RGB"))
|
136 |
else:
|
137 |
# Assume `image_transform` is a HF ImageProcessor...
|
138 |
+
pixel_values = self.vlm.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
|
139 |
|
140 |
if type(pixel_values) is dict:
|
141 |
for k in pixel_values.keys():
|
|
|
212 |
class DemoConfig:
|
213 |
# fmt: off
|
214 |
|
215 |
+
# === Model Parameters =>> Prismatic ===
|
216 |
+
model_family: str = "prismatic" # Model family to load from in < `prismatic` | `llava-v15` | ... >
|
217 |
+
model_id: str = "prism-dinosiglip+7b" # Model ID to load and run (instance of `model_family`)
|
218 |
+
model_dir: str = None # Can optionally supply model_dir instead of model_id
|
|
|
|
|
219 |
|
220 |
# === Model Parameters =>> Official LLaVa ===
|
221 |
# model_family: str = "llava-v15"
|
222 |
# model_id: str = "llava-v1.5-13b"
|
223 |
# model_dir: Path = "liuhaotian/llava-v1.5-13b"
|
224 |
|
225 |
+
# === Model Parameters =>> Official InstructBLIP ===
|
226 |
+
# model_family: str = "instruct-blip"
|
227 |
+
# model_id: str = "instructblip-vicuna-7b"
|
228 |
+
# model_dir: Path = "Salesforce/instructblip-vicuna-7b"
|
229 |
+
|
230 |
# Model Worker Parameters
|
231 |
host: str = "0.0.0.0"
|
232 |
port: int = 40000
|
233 |
controller_address: str = "http://localhost:10000"
|
|
|
234 |
limit_model_concurrency: int = 5
|
235 |
stream_interval: int = 1
|
236 |
no_register: bool = False
|
237 |
|
|
|
|
|
|
|
|
|
238 |
# HF Hub Credentials (for LLaMa-2)
|
239 |
hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
|
240 |
|
|
|
242 |
seed: int = 21 # Random Seed (for reproducibility)
|
243 |
|
244 |
def __post_init__(self) -> None:
|
245 |
+
self.run_dir = self.model_dir
|
246 |
+
self.model_name = MODEL_ID_TO_NAME[str(self.model_id)]
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
self.worker_address = f"http://localhost:{self.port}"
|
248 |
|
249 |
# fmt: on
|
|
|
263 |
global limit_model_concurrency
|
264 |
limit_model_concurrency = cfg.limit_model_concurrency
|
265 |
worker = ModelWorker(
|
266 |
+
cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_name
|
267 |
)
|
268 |
uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
|
269 |
|
serve/__init__.py
CHANGED
@@ -5,31 +5,24 @@ from collections import OrderedDict
|
|
5 |
MODEL_ID_TO_NAME = OrderedDict(
|
6 |
[
|
7 |
(
|
8 |
-
"
|
9 |
-
"PrismaticVLM 13B - Chat",
|
10 |
-
),
|
11 |
-
(
|
12 |
-
"llava-lvis4v-lrv+redux-lvis4v-lrv-resize-naive-dinosiglip-vit-so-14-384px-no-align+7b+stage-finetune+x7",
|
13 |
-
"PrismaticVLM 7B - Chat",
|
14 |
-
),
|
15 |
-
(
|
16 |
-
"llava-lvis4v-lrv+redux-lvis4v-lrv-resize-naive-dinosiglip-vit-so-14-384px-no-align-llama2pure+13b+stage-finetune+x7",
|
17 |
"PrismaticVLM 13B",
|
18 |
),
|
19 |
(
|
20 |
-
"
|
21 |
"PrismaticVLM 7B",
|
22 |
),
|
23 |
(
|
24 |
-
"
|
25 |
"PrismaticVLM 13B (Controlled)",
|
26 |
),
|
27 |
(
|
28 |
-
"
|
29 |
"PrismaticVLM 7B (Controlled)",
|
30 |
),
|
31 |
-
("llava-v1.5-13b", "LLaVA 1.5
|
32 |
-
("llava-v1.5-7b", "LLaVA 1.5
|
|
|
33 |
]
|
34 |
)
|
35 |
|
|
|
5 |
MODEL_ID_TO_NAME = OrderedDict(
|
6 |
[
|
7 |
(
|
8 |
+
"prism-dinosiglip+13b",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
"PrismaticVLM 13B",
|
10 |
),
|
11 |
(
|
12 |
+
"prism-dinosiglip+7b",
|
13 |
"PrismaticVLM 7B",
|
14 |
),
|
15 |
(
|
16 |
+
"prism-dinosiglip-controlled+13b",
|
17 |
"PrismaticVLM 13B (Controlled)",
|
18 |
),
|
19 |
(
|
20 |
+
"prism-dinosiglip-controlled+7b",
|
21 |
"PrismaticVLM 7B (Controlled)",
|
22 |
),
|
23 |
+
("llava-v1.5-13b", "LLaVA 1.5 13B"),
|
24 |
+
("llava-v1.5-7b", "LLaVA 1.5 7B"),
|
25 |
+
("instructblip-vicuna-7b", "InstructBLIP 7B"),
|
26 |
]
|
27 |
)
|
28 |
|
serve/gradio_web_server.py
CHANGED
@@ -93,24 +93,6 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
|
93 |
fout.write(json.dumps(data) + "\n")
|
94 |
|
95 |
|
96 |
-
# def upvote_last_response(state, model_selector, request: gr.Request):
|
97 |
-
# logger.info(f"upvote. ip: {request.client.host}")
|
98 |
-
# vote_last_response(state, "upvote", model_selector, request)
|
99 |
-
# return ("",) + (disable_btn,) * 3
|
100 |
-
|
101 |
-
|
102 |
-
# def downvote_last_response(state, model_selector, request: gr.Request):
|
103 |
-
# logger.info(f"downvote. ip: {request.client.host}")
|
104 |
-
# vote_last_response(state, "downvote", model_selector, request)
|
105 |
-
# return ("",) + (disable_btn,) * 3
|
106 |
-
|
107 |
-
|
108 |
-
# def flag_last_response(state, model_selector, request: gr.Request):
|
109 |
-
# logger.info(f"flag. ip: {request.client.host}")
|
110 |
-
# vote_last_response(state, "flag", model_selector, request)
|
111 |
-
# return ("",) + (disable_btn,) * 3
|
112 |
-
|
113 |
-
|
114 |
def regenerate(state, image_process_mode, request: gr.Request):
|
115 |
logger.info(f"regenerate. ip: {request.client.host}")
|
116 |
state.messages[-1][-1] = None
|
@@ -388,15 +370,6 @@ def build_demo(embed_mode):
|
|
388 |
|
389 |
# Register listeners
|
390 |
btn_list = [regenerate_btn, clear_btn]
|
391 |
-
# upvote_btn.click(
|
392 |
-
# upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
|
393 |
-
# )
|
394 |
-
# downvote_btn.click(
|
395 |
-
# downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
|
396 |
-
# )
|
397 |
-
# flag_btn.click(
|
398 |
-
# flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
|
399 |
-
# )
|
400 |
|
401 |
regenerate_btn.click(
|
402 |
regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, *btn_list], queue=False
|
|
|
93 |
fout.write(json.dumps(data) + "\n")
|
94 |
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
def regenerate(state, image_process_mode, request: gr.Request):
|
97 |
logger.info(f"regenerate. ip: {request.client.host}")
|
98 |
state.messages[-1][-1] = None
|
|
|
370 |
|
371 |
# Register listeners
|
372 |
btn_list = [regenerate_btn, clear_btn]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
regenerate_btn.click(
|
375 |
regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, *btn_list], queue=False
|