update gradio_web_server.py and model_worker.py
Changed files:
- gradio_web_server.py: +6 -1
- model_worker.py: +28 -17
gradio_web_server.py
CHANGED
@@ -818,7 +818,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int, default=11000)
-    parser.add_argument("--controller-url", type=str, default=
+    parser.add_argument("--controller-url", type=str, default=None)
     parser.add_argument("--concurrency-count", type=int, default=10)
     parser.add_argument(
         "--model-list-mode", type=str, default="once", choices=["once", "reload"]
@@ -829,6 +829,11 @@ if __name__ == "__main__":
     parser.add_argument("--embed", action="store_true")
     args = parser.parse_args()
     logger.info(f"args: {args}")
+    if not args.controller_url:
+        args.controller_url = os.environ.get("CONTROLLER_URL", None)
+
+    if not args.controller_url:
+        raise ValueError("controller-url is required.")
 
     models = get_model_list()
 
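Note: after this change the web server resolves the controller address in three steps: the --controller-url flag, then the CONTROLLER_URL environment variable, then a hard failure. A minimal standalone sketch of that resolution order (the URL value below is illustrative, not taken from this commit):

import argparse
import os

def resolve_controller_url(argv=None):
    # Mirrors the resolution order added above: flag, then env var, then error.
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-url", type=str, default=None)
    args = parser.parse_args(argv)
    if not args.controller_url:
        args.controller_url = os.environ.get("CONTROLLER_URL", None)
    if not args.controller_url:
        raise ValueError("controller-url is required.")
    return args.controller_url

# Example: the env var supplies the URL when the flag is absent.
os.environ["CONTROLLER_URL"] = "http://localhost:21001"  # illustrative value
print(resolve_controller_url([]))  # -> http://localhost:21001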
model_worker.py
CHANGED
@@ -160,6 +160,25 @@ def split_model(model_name):
     return device_map
 
 
+@spaces.GPU(duration=120)
+def multi_thread_infer(
+    model, tokenizer, pixel_values, question, history, generation_config
+):
+    with torch.no_grad():
+        thread = Thread(
+            target=model.chat,
+            kwargs=dict(
+                tokenizer=tokenizer,
+                pixel_values=pixel_values,
+                question=question,
+                history=history,
+                return_history=False,
+                generation_config=generation_config,
+            ),
+        )
+        thread.start()
+
+
 class ModelWorker:
     def __init__(
         self,
@@ -325,8 +344,6 @@ class ModelWorker:
             "queue_length": self.get_queue_length(),
         }
 
-    # @torch.inference_mode()
-    @spaces.GPU(duration=120)
     def generate_stream(self, params):
         system_message = params["prompt"][0]["content"]
         send_messages = params["prompt"][1:]
@@ -428,20 +445,14 @@ class ModelWorker:
             streamer=streamer,
         )
         logger.info(f"Generation config: {generation_config}")
-
-        with torch.no_grad():
-            thread = Thread(
-                target=self.model.chat,
-                kwargs=dict(
-                    tokenizer=self.tokenizer,
-                    pixel_values=pixel_values,
-                    question=question,
-                    history=history,
-                    return_history=False,
-                    generation_config=generation_config,
-                ),
-            )
-            thread.start()
+        multi_thread_infer(
+            self.model,
+            self.tokenizer,
+            pixel_values,
+            question,
+            history,
+            generation_config,
+        )
 
         generated_text = ""
         for new_text in streamer:
@@ -541,4 +552,4 @@ if __name__ == "__main__":
         args.load_8bit,
         args.device,
     )
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info"
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
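Note: the model_worker change hoists the Thread launch around model.chat out of generate_stream into the module-level multi_thread_infer, presumably so that the ZeroGPU @spaces.GPU(duration=120) decorator wraps only the GPU-bound call rather than the whole streaming method. The producer/consumer shape is unchanged: generation runs in a background thread while the caller drains the streamer on the main thread. A self-contained sketch of that pattern, using stand-ins (DummyStreamer and fake_chat below are placeholders, not the worker's real transformers.TextIteratorStreamer or model.chat):

from queue import Queue
from threading import Thread

class DummyStreamer:
    # Stand-in for transformers.TextIteratorStreamer: an iterable token queue.
    _END = object()

    def __init__(self):
        self._queue = Queue()

    def put(self, text):
        self._queue.put(text)

    def end(self):
        self._queue.put(self._END)

    def __iter__(self):
        while True:
            item = self._queue.get()
            if item is self._END:
                return
            yield item

def fake_chat(streamer, tokens):
    # Plays the role of model.chat(..., streamer=streamer): emit, then close.
    for tok in tokens:
        streamer.put(tok)
    streamer.end()

streamer = DummyStreamer()
# Like multi_thread_infer: start generation in the background, do not join yet.
thread = Thread(target=fake_chat, kwargs=dict(streamer=streamer, tokens=["Hel", "lo", "!"]))
thread.start()

generated_text = ""
for new_text in streamer:  # main thread consumes tokens as they arrive
    generated_text += new_text
print(generated_text)  # -> Hello!
thread.join()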