David Day committed on
Commit 4d8a17e
1 Parent(s): 9e1deca

Fix ZeroGPU issue.

Files changed (2)
  1. app.py +21 -131
  2. model_worker.py +4 -144
app.py CHANGED
@@ -3,8 +3,6 @@ import datetime
 import hashlib
 import json
 import os
-import subprocess
-import sys
 import time
 
 import gradio as gr
@@ -13,22 +11,17 @@ import requests
 from constants import LOGDIR
 from conversation import (default_conversation, conv_templates,
                           SeparatorStyle)
-from utils import (build_logger, server_error_msg,
-                   violates_moderation, moderation_msg)
+from utils import (build_logger, server_error_msg)
 
 
 logger = build_logger("gradio_web_server", "gradio_web_server.log")
 
-headers = {"User-Agent": "LLaVA Client"}
+from model_worker import ModelWorker
 
 no_change_btn = gr.Button()
 enable_btn = gr.Button(interactive=True)
 disable_btn = gr.Button(interactive=False)
 
-priority = {
-    "vicuna-13b": "aaaaaaa",
-    "koala-13b": "aaaaaab",
-}
 
 
 def get_conv_log_filename():
@@ -36,17 +29,6 @@ def get_conv_log_filename():
     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
     return name
 
-
-def get_model_list():
-    ret = requests.post(args.controller_url + "/refresh_all_workers")
-    assert ret.status_code == 200
-    ret = requests.post(args.controller_url + "/list_models")
-    models = ret.json()["models"]
-    models.sort(key=lambda x: priority.get(x, x))
-    logger.info(f"Models: {models}")
-    return models
-
-
 get_window_url_params = """
 function() {
     const params = new URLSearchParams(window.location.search);
@@ -60,24 +42,11 @@ function() {
 def load_demo(url_params, request: gr.Request):
     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
 
+    global worker
     dropdown_update = gr.Dropdown(visible=True)
-    if "model" in url_params:
-        model = url_params["model"]
-        if model in models:
-            dropdown_update = gr.Dropdown(value=model, visible=True)
-
-    state = default_conversation.copy()
-    return state, dropdown_update
-
-
-def load_demo_refresh_model_list(request: gr.Request):
-    logger.info(f"load_demo. ip: {request.client.host}")
-    models = get_model_list()
+    worker = ModelWorker(model_path, None, model_name, True, lora_path)
+
     state = default_conversation.copy()
-    dropdown_update = gr.Dropdown(
-        choices=models,
-        value=models[0] if len(models) > 0 else ""
-    )
     return state, dropdown_update
 
 
@@ -132,12 +101,6 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     if len(text) <= 0 and image is None:
         state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
-    if args.moderate:
-        flagged = violates_moderation(text)
-        if flagged:
-            state.skip_next = True
-            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
-                no_change_btn,) * 5
 
     text = text[:1536]  # Hard cut-off
     if image is not None:
@@ -152,7 +115,6 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
     state.skip_next = False
     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
 
-
 def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
     logger.info(f"http_bot. ip: {request.client.host}")
     start_tstamp = time.time()
@@ -204,18 +166,6 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
     new_state.append_message(new_state.roles[1], None)
     state = new_state
 
-    # Query worker address
-    controller_url = args.controller_url
-    ret = requests.post(controller_url + "/get_worker_address",
-                        json={"model": model_name})
-    worker_addr = ret.json()["address"]
-    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
-
-    # No available worker
-    if worker_addr == "":
-        state.messages[-1][-1] = server_error_msg
-        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-        return
 
     # Construct prompt
     prompt = state.get_prompt()
@@ -248,9 +198,7 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request:
 
     try:
         # Stream output
-        response = requests.post(worker_addr + "/worker_generate_stream",
-                                 headers=headers, json=pload, stream=True, timeout=10)
-        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+        for chunk in worker.generate_stream_gate(pload):
             if chunk:
                 data = json.loads(chunk.decode())
                 if data["error_code"] == 0:
@@ -312,14 +260,12 @@ block_css = """
 
 """
 
-def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
+def build_demo(cur_dir=None, concurrency_count=10):
     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
     with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
         state = gr.State()
 
-        if not embed_mode:
-            gr.Markdown(title_markdown)
-
+        gr.Markdown(title_markdown)
         with gr.Row():
             with gr.Column(scale=2):
                 # add a description
@@ -381,9 +327,8 @@ This is the demo for Dr-LLaVA. So far it could only be used for H&E stained Bone
                 regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                 clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
 
-        if not embed_mode:
-            gr.Markdown(tos_markdown)
-            gr.Markdown(learn_more_markdown)
+        gr.Markdown(tos_markdown)
+        gr.Markdown(learn_more_markdown)
         url_params = gr.JSON(visible=False)
 
         # Register listeners
@@ -444,84 +389,29 @@ This is the demo for Dr-LLaVA. So far it could only be used for H&E stained Bone
             [state, chatbot] + btn_list,
             concurrency_limit=concurrency_count
         )
-
-        if args.model_list_mode == "once":
-            demo.load(
-                load_demo,
-                [url_params],
-                [state, model_selector],
-                js=get_window_url_params
-            )
-        elif args.model_list_mode == "reload":
-            demo.load(
-                load_demo_refresh_model_list,
-                None,
-                [state, model_selector],
-                queue=False
-            )
-        else:
-            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
-
+
+        demo.load(
+            load_demo,
+            [url_params],
+            [state, model_selector],
+            js=get_window_url_params
+        )
+
     return demo
 
-
-def start_controller():
-    logger.info("Starting the controller")
-    controller_command = [
-        "python",
-        "-m",
-        "controller",
-        "--host",
-        "0.0.0.0",
-        "--port",
-        "10000",
-    ]
-    return subprocess.Popen(controller_command)
-
-
-def start_worker():
-    logger.info(f"Starting the model worker")
-    worker_command = [
-        "python",
-        "-m",
-        "model_worker",
-        "--host",
-        "0.0.0.0",
-        "--controller",
-        "http://localhost:10000",
-        "--load-bf16",
-        "--model-name",
-        "llava-rlhf-13b-v1.5-336",
-        "--model-path",
-        "daviddaytw/Dr-LLaVA-sft",
-        "--lora-path",
-        "daviddaytw/Dr-LLaVA-lora-adapter",
-    ]
-    return subprocess.Popen(worker_command)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int)
-    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
     parser.add_argument("--concurrency-count", type=int, default=16)
-    parser.add_argument(
-        "--model-list-mode", type=str, default="reload", choices=["once", "reload"]
-    )
     parser.add_argument("--share", action="store_true")
-    parser.add_argument("--moderate", action="store_true")
-    parser.add_argument("--embed", action="store_true")
     args = parser.parse_args()
     logger.info(f"args: {args}")
 
-    controller_proc = start_controller()
-    worker_proc = start_worker()
-
-    # Wait for worker and controller to start
-    time.sleep(60)
-
-    models = get_model_list()
-
-    logger.info(args)
-    demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
+    models = ['llava-rlhf-13b-v1.5-336']
+    model_path = 'daviddaytw/Dr-LLaVA-sft'
+    model_name = 'llava-rlhf-13b-v1.5-336'
+    lora_path = 'daviddaytw/Dr-LLaVA-lora-adapter'
+    demo = build_demo(concurrency_count=args.concurrency_count)
    demo.queue(
        api_open=False
    ).launch(
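
Taken together, the app.py changes replace the controller/worker HTTP round-trip with a direct in-process call. A minimal sketch of the resulting flow, assuming the ModelWorker defined in this commit's model_worker.py; the pload fields below are illustrative (following the LLaVA streaming protocol) and are not shown in this diff:

import json
from model_worker import ModelWorker  # in-process worker from this commit

# Hard-coded model identifiers, as set in app.py's __main__ block above.
model_path = 'daviddaytw/Dr-LLaVA-sft'
model_name = 'llava-rlhf-13b-v1.5-336'
lora_path = 'daviddaytw/Dr-LLaVA-lora-adapter'

# load_demo() constructs the worker lazily on first page load.
worker = ModelWorker(model_path, None, model_name, True, lora_path)

# http_bot() then consumes the generator directly; each chunk is a
# b"\0"-terminated JSON record, so strip the terminator before parsing.
pload = {"prompt": "Describe this image.", "temperature": 0.2,
         "top_p": 0.7, "max_new_tokens": 512}  # illustrative payload
for chunk in worker.generate_stream_gate(pload):
    if chunk:
        data = json.loads(chunk.decode().rstrip("\x00"))
        if data["error_code"] == 0:
            print(data["text"], end="", flush=True)  # "text" per LLaVA protocol
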
model_worker.py CHANGED
@@ -1,26 +1,14 @@
 """
 A model worker executes the model.
 """
-import argparse
-import asyncio
 import json
-import time
-import threading
 import uuid
-
-from fastapi import FastAPI, Request, BackgroundTasks
-from fastapi.responses import StreamingResponse
-import requests
 import torch
-import uvicorn
-from functools import partial
 import spaces
 
 from peft import PeftModel
 
-from llava.constants import WORKER_HEART_BEAT_INTERVAL
-from llava.utils import (build_logger, server_error_msg,
-                         pretty_print_semaphore)
+from llava.utils import (build_logger, server_error_msg)
 from model_builder import load_pretrained_model
 from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
 from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
@@ -37,20 +25,8 @@ global_counter = 0
 model_semaphore = None
 
 
-def heart_beat_worker(controller):
-
-    while True:
-        time.sleep(WORKER_HEART_BEAT_INTERVAL)
-        controller.send_heart_beat()
-
-
 class ModelWorker:
-    def __init__(self, controller_addr, worker_addr,
-                 worker_id, no_register,
-                 model_path, model_base, model_name,
-                 load_8bit, load_4bit, load_bf16, lora_path):
-        self.controller_addr = controller_addr
-        self.worker_addr = worker_addr
+    def __init__(self, model_path, model_base, model_name, load_bf16, lora_path):
         self.worker_id = worker_id
         if model_path.endswith("/"):
             model_path = model_path[:-1]
@@ -65,7 +41,7 @@ class ModelWorker:
 
         logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-            model_path, model_base, self.model_name, load_8bit, load_4bit, load_bf16=load_bf16)
+            model_path, model_base, self.model_name, False, False, load_bf16=load_bf16)
         self.is_multimodal = 'llava' in self.model_name.lower()
         self.load_bf16 = load_bf16
 
@@ -76,59 +52,7 @@ class ModelWorker:
             torch_device='cpu',
             device_map="cpu",
         )
-
-        if not no_register:
-            self.register_to_controller()
-            self.heart_beat_thread = threading.Thread(
-                target=heart_beat_worker, args=(self,))
-            self.heart_beat_thread.start()
-
-    def register_to_controller(self):
-        logger.info("Register to controller")
-
-        url = self.controller_addr + "/register_worker"
-        data = {
-            "worker_name": self.worker_addr,
-            "check_heart_beat": True,
-            "worker_status": self.get_status()
-        }
-        r = requests.post(url, json=data)
-        assert r.status_code == 200
-
-    def send_heart_beat(self):
-        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
-                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
-                    f"global_counter: {global_counter}")
-
-        url = self.controller_addr + "/receive_heart_beat"
-
-        while True:
-            try:
-                ret = requests.post(url, json={
-                    "worker_name": self.worker_addr,
-                    "queue_length": self.get_queue_length()}, timeout=5)
-                exist = ret.json()["exist"]
-                break
-            except requests.exceptions.RequestException as e:
-                logger.error(f"heart beat error: {e}")
-                time.sleep(5)
-
-        if not exist:
-            self.register_to_controller()
-
-    def get_queue_length(self):
-        if model_semaphore is None:
-            return 0
-        else:
-            return args.limit_model_concurrency - model_semaphore._value + (len(
-                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
-
-    def get_status(self):
-        return {
-            "model_names": [self.model_name],
-            "speed": 1,
-            "queue_length": self.get_queue_length(),
-        }
+        self.model.to("cuda:0")
 
     @spaces.GPU
     def generate_stream(self, params):
@@ -232,71 +156,7 @@ class ModelWorker:
         }
         yield json.dumps(ret).encode() + b"\0"
 
-
-app = FastAPI()
-
-
 def release_model_semaphore(fn=None):
     model_semaphore.release()
     if fn is not None:
         fn()
-
-
-@app.post("/worker_generate_stream")
-async def generate_stream(request: Request):
-    global model_semaphore, global_counter
-    global_counter += 1
-    params = await request.json()
-
-    if model_semaphore is None:
-        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
-    await model_semaphore.acquire()
-    worker.send_heart_beat()
-    generator = worker.generate_stream_gate(params)
-    background_tasks = BackgroundTasks()
-    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
-    return StreamingResponse(generator, background=background_tasks)
-
-
-@app.post("/worker_get_status")
-async def get_status(request: Request):
-    return worker.get_status()
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=21002)
-    parser.add_argument("--worker-address", type=str,
-                        default="http://localhost:21002")
-    parser.add_argument("--controller-address", type=str,
-                        default="http://localhost:21001")
-    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
-    parser.add_argument("--model-base", type=str, default=None)
-    parser.add_argument("--model-name", type=str)
-    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
-    parser.add_argument("--limit-model-concurrency", type=int, default=5)
-    parser.add_argument("--stream-interval", type=int, default=1)
-    parser.add_argument("--no-register", action="store_true")
-    parser.add_argument("--load-8bit", action="store_true")
-    parser.add_argument("--load-4bit", action="store_true")
-    parser.add_argument("--load-bf16", action="store_true")
-    parser.add_argument("--lora-path", type=str, default=None)
-    args = parser.parse_args()
-    logger.info(f"args: {args}")
-
-    if args.multi_modal:
-        logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
-
-    worker = ModelWorker(args.controller_address,
-                         args.worker_address,
-                         worker_id,
-                         args.no_register,
-                         args.model_path,
-                         args.model_base,
-                         args.model_name,
-                         args.load_8bit,
-                         args.load_4bit,
-                         args.load_bf16,
-                         args.lora_path)
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
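
On the model_worker.py side the commit converges on the standard ZeroGPU shape: load weights on CPU at startup, move them to cuda:0, and confine GPU execution to @spaces.GPU-decorated methods. A toy sketch of that shape, assuming the Hugging Face `spaces` package available on ZeroGPU hardware; only the spaces.GPU decorator and the cuda:0 transfer are confirmed by the diff, and the Linear layer is a stand-in:

import spaces
import torch

class TinyWorker:
    def __init__(self):
        # Weights prepared at startup; under ZeroGPU the "cuda:0"
        # transfer is deferred until a GPU is actually attached.
        self.model = torch.nn.Linear(8, 8)  # stand-in for the LLaVA model
        self.model.to("cuda:0")

    @spaces.GPU  # a GPU is held only while this method runs
    def generate(self, x: torch.Tensor) -> torch.Tensor:
        with torch.inference_mode():
            return self.model(x.to("cuda:0"))

This also explains what was deleted rather than ported: the FastAPI endpoints, controller registration, and heartbeat thread serve a multi-process deployment, while a ZeroGPU Space runs the worker inside the single Gradio process.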