abalakrishnaTRI committed on
Commit
5b53c67
1 Parent(s): 6ba6dce
Files changed (1) hide show
  1. interactive_demo.py +16 -39
interactive_demo.py CHANGED
@@ -47,20 +47,12 @@ def heart_beat_worker(controller):
47
 
48
 
49
  class ModelWorker:
50
- def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_base, model_name):
51
  self.controller_addr = controller_addr
52
  self.worker_addr = worker_addr
53
  self.worker_id = worker_id
54
  self.model_name = model_name
55
-
56
- # logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
57
  self.vlm = vlm
58
- self.tokenizer, self.model, self.image_processor, self.context_len = (
59
- vlm.tokenizer,
60
- vlm.model,
61
- vlm.image_processor,
62
- vlm.max_length,
63
- )
64
 
65
  if not no_register:
66
  self.register_to_controller()
@@ -68,18 +60,12 @@ class ModelWorker:
68
  self.heart_beat_thread.start()
69
 
70
  def register_to_controller(self):
71
- # logger.info("Register to controller")
72
-
73
  url = self.controller_addr + "/register_worker"
74
  data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
75
  r = requests.post(url, json=data)
76
  assert r.status_code == 200
77
 
78
  def send_heart_beat(self):
79
- # logger.info(f"Send heart beat. Models: {[self.model_name]}. "
80
- # f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
81
- # f"global_counter: {global_counter}")
82
-
83
  url = self.controller_addr + "/receive_heart_beat"
84
 
85
  while True:
@@ -91,7 +77,6 @@ class ModelWorker:
91
  break
92
  except requests.exceptions.RequestException:
93
  pass
94
- # logger.error(f"heart beat error: {e}")
95
  time.sleep(5)
96
 
97
  if not exist:
@@ -145,12 +130,12 @@ class ModelWorker:
145
  else:
146
  question_prompt = [prompt_fn()]
147
 
148
- if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
149
  # This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
150
- pixel_values = self.image_processor(images[0].convert("RGB"))
151
  else:
152
  # Assume `image_transform` is a HF ImageProcessor...
153
- pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
154
 
155
  if type(pixel_values) is dict:
156
  for k in pixel_values.keys():
@@ -227,31 +212,29 @@ overwatch = initialize_overwatch(__name__)
227
  class DemoConfig:
228
  # fmt: off
229
 
230
- # === Model Parameters =>> Quartz ===
231
- model_family: str = "quartz" # Model family to load from in < `quartz` | `llava-v15` | ... >
232
- model_id: str = "llava-v1.5-7b" # Model ID to load and run (instance of `model_family`)
233
- model_dir: Path = ( # Path to model checkpoint to load --> should be self-contained
234
- "resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
235
- )
236
 
237
  # === Model Parameters =>> Official LLaVa ===
238
  # model_family: str = "llava-v15"
239
  # model_id: str = "llava-v1.5-13b"
240
  # model_dir: Path = "liuhaotian/llava-v1.5-13b"
241
 
 
 
 
 
 
242
  # Model Worker Parameters
243
  host: str = "0.0.0.0"
244
  port: int = 40000
245
  controller_address: str = "http://localhost:10000"
246
- model_base: str = "llava-v15"
247
  limit_model_concurrency: int = 5
248
  stream_interval: int = 1
249
  no_register: bool = False
250
 
251
- # Inference Parameters
252
- device_batch_size: int = 1 # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs!
253
- num_workers: int = 2 # Number of Dataloader Workers (on each process)
254
-
255
  # HF Hub Credentials (for LLaMa-2)
256
  hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
257
 
@@ -259,14 +242,8 @@ class DemoConfig:
259
  seed: int = 21 # Random Seed (for reproducibility)
260
 
261
  def __post_init__(self) -> None:
262
- if self.model_family == "quartz":
263
- self.model_name = MODEL_ID_TO_NAME[str(self.model_dir)]
264
- self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
265
- elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
266
- self.model_name = MODEL_ID_TO_NAME[self.model_id]
267
- self.run_dir = self.model_dir
268
- else:
269
- raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")
270
  self.worker_address = f"http://localhost:{self.port}"
271
 
272
  # fmt: on
@@ -286,7 +263,7 @@ def interactive_demo(cfg: DemoConfig):
286
  global limit_model_concurrency
287
  limit_model_concurrency = cfg.limit_model_concurrency
288
  worker = ModelWorker(
289
- cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_base, cfg.model_name
290
  )
291
  uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
292
 
 
47
 
48
 
49
  class ModelWorker:
50
+ def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_name):
51
  self.controller_addr = controller_addr
52
  self.worker_addr = worker_addr
53
  self.worker_id = worker_id
54
  self.model_name = model_name
 
 
55
  self.vlm = vlm
 
 
 
 
 
 
56
 
57
  if not no_register:
58
  self.register_to_controller()
 
60
  self.heart_beat_thread.start()
61
 
62
  def register_to_controller(self):
 
 
63
  url = self.controller_addr + "/register_worker"
64
  data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
65
  r = requests.post(url, json=data)
66
  assert r.status_code == 200
67
 
68
  def send_heart_beat(self):
 
 
 
 
69
  url = self.controller_addr + "/receive_heart_beat"
70
 
71
  while True:
 
77
  break
78
  except requests.exceptions.RequestException:
79
  pass
 
80
  time.sleep(5)
81
 
82
  if not exist:
 
130
  else:
131
  question_prompt = [prompt_fn()]
132
 
133
+ if isinstance(self.vlm.image_processor, Compose) or hasattr(self.vlm.image_processor, "is_prismatic"):
134
  # This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
135
+ pixel_values = self.vlm.image_processor(images[0].convert("RGB"))
136
  else:
137
  # Assume `image_transform` is a HF ImageProcessor...
138
+ pixel_values = self.vlm.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
139
 
140
  if type(pixel_values) is dict:
141
  for k in pixel_values.keys():
 
212
  class DemoConfig:
213
  # fmt: off
214
 
215
+ # === Model Parameters =>> Prismatic ===
216
+ model_family: str = "prismatic" # Model family to load from in < `prismatic` | `llava-v15` | ... >
217
+ model_id: str = "prism-dinosiglip+7b" # Model ID to load and run (instance of `model_family`)
218
+ model_dir: str = None # Can optionally supply model_dir instead of model_id
 
 
219
 
220
  # === Model Parameters =>> Official LLaVa ===
221
  # model_family: str = "llava-v15"
222
  # model_id: str = "llava-v1.5-13b"
223
  # model_dir: Path = "liuhaotian/llava-v1.5-13b"
224
 
225
+ # === Model Parameters =>> Official InstructBLIP ===
226
+ # model_family: str = "instruct-blip"
227
+ # model_id: str = "instructblip-vicuna-7b"
228
+ # model_dir: Path = "Salesforce/instructblip-vicuna-7b"
229
+
230
  # Model Worker Parameters
231
  host: str = "0.0.0.0"
232
  port: int = 40000
233
  controller_address: str = "http://localhost:10000"
 
234
  limit_model_concurrency: int = 5
235
  stream_interval: int = 1
236
  no_register: bool = False
237
 
 
 
 
 
238
  # HF Hub Credentials (for LLaMa-2)
239
  hf_token: Union[str, Path] = Path(".hf_token") # Environment variable or Path to HF Token
240
 
 
242
  seed: int = 21 # Random Seed (for reproducibility)
243
 
244
  def __post_init__(self) -> None:
245
+ self.run_dir = self.model_dir
246
+ self.model_name = MODEL_ID_TO_NAME[str(self.model_id)]
 
 
 
 
 
 
247
  self.worker_address = f"http://localhost:{self.port}"
248
 
249
  # fmt: on
 
263
  global limit_model_concurrency
264
  limit_model_concurrency = cfg.limit_model_concurrency
265
  worker = ModelWorker(
266
+ cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_name
267
  )
268
  uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
269