Delete mplug_docowl/serve
Files changed:

- mplug_docowl/serve/__init__.py (+0, -0)
- mplug_docowl/serve/cli.py (+0, -120)
- mplug_docowl/serve/controller.py (+0, -298)
- mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg (+0, -0)
- mplug_docowl/serve/examples/extreme_ironing.jpg (+0, -0)
- mplug_docowl/serve/gradio_web_server.py (+0, -460)
- mplug_docowl/serve/model_worker.py (+0, -342)
- mplug_docowl/serve/model_worker_bak.py (+0, -278)
mplug_docowl/serve/__init__.py
DELETED
File without changes (empty file).
mplug_docowl/serve/cli.py
DELETED
@@ -1,120 +0,0 @@

```python
import argparse
import torch

from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates, SeparatorStyle
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def load_image(image_file):
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def main(args):
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    conv_mode = "mplug_owl2"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    roles = conv.roles

    image = load_image(args.image_file)
    # Similar operation in model_worker.py
    image_tensor = process_images([image], image_processor, args)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    while True:
        try:
            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{roles[1]}: ", end="")

        if image is not None:
            # first message
            inp = DEFAULT_IMAGE_TOKEN + inp
            conv.append_message(conv.roles[0], inp)
            image = None
        else:
            # later messages
            conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        conv.messages[-1][-1] = outputs

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    args = parser.parse_args()
    main(args)
```
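For context, the deleted CLI was driven entirely by the argparse flags above. The following is a minimal sketch of an equivalent programmatic invocation; the checkpoint id and image path are placeholders, not values from this repository, and `main` refers to the function in the deleted module.

```python
# Hypothetical programmatic invocation of the deleted CLI; all values below
# are placeholders mirroring the argparse defaults above.
import argparse

args = argparse.Namespace(
    model_path="path/to/checkpoint",      # placeholder
    model_base=None,
    image_file="examples/extreme_ironing.jpg",
    device="cuda",
    conv_mode=None,
    temperature=0.2,
    max_new_tokens=512,
    load_8bit=False,
    load_4bit=False,
    debug=False,
    image_aspect_ratio="pad",
)
# main(args)  # from mplug_docowl.serve.cli, now removed
```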
mplug_docowl/serve/controller.py
DELETED
@@ -1,298 +0,0 @@

```python
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn

from mplug_owl2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from mplug_owl2.utils import build_logger, server_error_msg


logger = build_logger("controller", "controller.log")


class DispatchMethod(Enum):
    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        if name == "lottery":
            return cls.LOTTERY
        elif name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        else:
            raise ValueError(f"Invalid dispatch method")


@dataclasses.dataclass
class WorkerInfo:
    model_names: List[str]
    speed: int
    queue_length: int
    check_heart_beat: bool
    last_heart_beat: str


def heart_beat_controller(controller):
    while True:
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stable_workers_by_expiration()


class Controller:
    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,))
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
            if not worker_status:
                return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        model_names = set()

        for w_name, w_info in self.worker_info.items():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            if True:  # Directly return address
                pt = np.random.choice(np.arange(len(worker_names)),
                                      p=worker_speeds)
                worker_name = worker_names[pt]
                return worker_name

            # Check status before returning
            while True:
                pt = np.random.choice(np.arange(len(worker_names)),
                                      p=worker_speeds)
                worker_name = worker_names[pt]

                if self.get_worker_status(worker_name):
                    break
                else:
                    self.remove_worker(worker_name)
                    worker_speeds[pt] = 0
                    norm = np.sum(worker_speeds)
                    if norm < 1e-4:
                        return ""
                    worker_speeds = worker_speeds / norm
                    continue
            return worker_name
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        to_delete = []
        for worker_name, w_info in self.worker_info.items():
            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
                to_delete.append(worker_name)

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException as e:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }


app = FastAPI()


@app.post("/register_worker")
async def register_worker(request: Request):
    data = await request.json()
    controller.register_worker(
        data["worker_name"], data["check_heart_beat"],
        data.get("worker_status", None))


@app.post("/refresh_all_workers")
async def refresh_all_workers():
    models = controller.refresh_all_workers()


@app.post("/list_models")
async def list_models():
    models = controller.list_models()
    return {"models": models}


@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    data = await request.json()
    addr = controller.get_worker_address(data["model"])
    return {"address": addr}


@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    data = await request.json()
    exist = controller.receive_heart_beat(
        data["worker_name"], data["queue_length"])
    return {"exist": exist}


@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    params = await request.json()
    generator = controller.worker_api_generate_stream(params)
    return StreamingResponse(generator)


@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    return controller.worker_api_get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
    parser.add_argument("--dispatch-method", type=str, choices=[
        "lottery", "shortest_queue"], default="shortest_queue")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    controller = Controller(args.dispatch_method)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
```
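The controller exposes a small JSON-over-HTTP API. Below is a minimal client sketch against the routes defined above, assuming a controller running on its default port 21001; the model name is whatever the workers registered.

```python
# Minimal sketch of a client for the controller routes above
# (default host/port assumed; no values here come from a live deployment).
import requests

controller = "http://localhost:21001"

# Re-poll all registered workers, then list the models they serve.
requests.post(controller + "/refresh_all_workers")
models = requests.post(controller + "/list_models").json()["models"]

# Ask the dispatcher which worker should serve the first model.
addr = requests.post(controller + "/get_worker_address",
                     json={"model": models[0]}).json()["address"]
print(models, addr)
```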
mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg
DELETED
Binary file (18.9 kB)

mplug_docowl/serve/examples/extreme_ironing.jpg
DELETED
Binary file (62.6 kB)
mplug_docowl/serve/gradio_web_server.py
DELETED
@@ -1,460 +0,0 @@

```python
import argparse
import datetime
import json
import os
import time

import gradio as gr
import requests

from mplug_owl2.conversation import (default_conversation, conv_templates,
                                     SeparatorStyle)
from mplug_owl2.constants import LOGDIR
from mplug_owl2.utils import (build_logger, server_error_msg,
                              violates_moderation, moderation_msg)
import hashlib


logger = build_logger("gradio_web_server", "gradio_web_server.log")

headers = {"User-Agent": "mPLUG-Owl2 Client"}

no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

priority = {
    "vicuna-13b": "aaaaaaa",
    "koala-13b": "aaaaaab",
}


def get_conv_log_filename():
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
    return name


def get_model_list():
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    assert ret.status_code == 200
    ret = requests.post(args.controller_url + "/list_models")
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models


get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""


def load_demo(url_params, request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown.update(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown.update(
                value=model, visible=True)

    state = default_conversation.copy()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()
    dropdown_update = gr.Dropdown.update(
        choices=models,
        value=models[0] if len(models) > 0 else ""
    )
    return state, dropdown_update


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def downvote_last_response(state, model_selector, request: gr.Request):
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 3


def flag_last_response(state, model_selector, request: gr.Request):
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return ("",) + (disable_btn,) * 3


def regenerate(state, image_process_mode, request: gr.Request):
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if type(prev_human_msg[1]) in (tuple, list):
        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def clear_history(request: gr.Request):
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def add_text(state, text, image, image_process_mode, request: gr.Request):
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if len(text) <= 0 and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
    if args.moderate:
        flagged = violates_moderation(text)
        if flagged:
            state.skip_next = True
            return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
                no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<|image|>' not in text:
            # text = text + '<|image|>'
            text = '<|image|>' + text
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = "mplug_owl2"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Construct prompt
    prompt = state.get_prompt()

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(start_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


title_markdown = ("""
<h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png", alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>

<h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>

<h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h2>

<div align="center">
    <div style="display:flex; gap: 0.25rem;" align="center">
        <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
        <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
        <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
    </div>
</div>

""")


tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")


learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""

def build_demo(embed_mode):
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
                    [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=True) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=False
    )
```
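Taken together, the deleted modules formed a controller/worker/web-UI serving stack. The sketch below shows the launch order implied by their argparse defaults; the module paths and commands are reconstructed from the code above, not documented invocations, and the checkpoint path is a placeholder.

```python
# Hypothetical launch order for the deleted serving stack, reconstructed
# from the argparse defaults in the files above (not verified commands):
#
#   1. Controller (dispatcher), default port 21001:
#      python -m mplug_docowl.serve.controller --host 0.0.0.0 --port 21001
#
#   2. Model worker, registers itself with the controller:
#      python -m mplug_docowl.serve.model_worker \
#          --controller-address http://localhost:21001 \
#          --worker-address http://localhost:21002 \
#          --model-path <checkpoint>
#
#   3. Gradio web server, queries the controller for available models:
#      python -m mplug_docowl.serve.gradio_web_server \
#          --controller-url http://localhost:21001 --port 7860
```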
mplug_docowl/serve/model_worker.py
DELETED
@@ -1,342 +0,0 @@

```python
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from mplug_docowl.utils import (build_logger, server_error_msg,
                                pretty_print_semaphore)

from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, WORKER_HEART_BEAT_INTERVAL
from mplug_docowl.conversation import conv_templates, SeparatorStyle
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from mplug_docowl.processor import DocProcessor


from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):

    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class DocOwlInfer():
    def __init__(self, ckpt_path, anchors='grid_9', add_global_img=True, load_8bit=False, load_4bit=False):
        model_name = get_model_name_from_path(ckpt_path)
        ic(model_name)  # NOTE: neither `ic` (icecream) nor `TextStreamer` below was imported in this file
        self.tokenizer, self.model, _, _ = load_pretrained_model(ckpt_path, None, model_name, load_8bit=load_8bit, load_4bit=load_4bit, device="cuda")
        self.doc_image_processor = DocProcessor(image_size=448, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
        self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

    def inference(self, image, query):
        image_tensor, patch_positions, text = self.doc_image_processor(images=image, query='<|image|>'+query)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
        patch_positions = patch_positions.to(self.model.device)

        # ic(image_tensor.shape, patch_positions.shape, text)

        conv = conv_templates["mplug_owl2"].copy()
        roles = conv.roles  # ("USER", "ASSISTANT")

        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # ic(prompt)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)

        # ic(input_ids)

        stop_str = conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                patch_positions=patch_positions,
                do_sample=False,
                temperature=1.0,
                max_new_tokens=512,
                streamer=self.streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

        return outputs.replace('</s>', '')

# TODO: adapt for docowl infer
class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 resolution, anchors, add_global_img,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]

        self.model_name = get_model_name_from_path(ckpt_path)  # BUG in deleted file: `ckpt_path` is undefined here; `model_path` was presumably intended

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")

        self.tokenizer, self.model, _, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)

        self.resolution = resolution
        self.token_num_each_img = (self.resolution/14)*(self.resolution/14)/self.model.get_model().vison2text.conv_patch
        self.doc_image_processor = DocProcessor(image_size=resolution, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)


        self.is_multimodal = True

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor  # NOTE: self.image_processor is never assigned in __init__

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:

                """if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)"""

                # docowl only support 1 image, so only keep the last image
                image = images[-1]
                assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1

                image_tensor, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
                image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
                patch_positions = patch_positions.to(self.model.device)

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
                num_image_tokens = prompt.count(replace_token) * (self.token_num_each_img+1)
            else:
                images = None
                patch_positions = None
            image_args = {"images": images, "patch_positions": patch_positions}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"

app = FastAPI()

def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")


    # NOTE: this call omits the resolution/anchors/add_global_img arguments
    # that ModelWorker.__init__ above requires
    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
```
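The `DocOwlInfer` class above was a self-contained inference helper, independent of the worker/controller machinery. A usage sketch follows, assuming the module were still importable; the checkpoint path and image file are placeholders.

```python
# Hypothetical usage of the deleted DocOwlInfer helper; the checkpoint path
# and image file are placeholders, not values from this repository.
from PIL import Image
# from mplug_docowl.serve.model_worker import DocOwlInfer  # now removed

infer = DocOwlInfer(ckpt_path="./checkpoints/docowl", anchors="grid_9")
image = Image.open("sample_doc.png").convert("RGB")
answer = infer.inference(image, "What is the title of this document?")
print(answer)
```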
mplug_docowl/serve/model_worker_bak.py
DELETED
@@ -1,278 +0,0 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
from mplug_owl2.utils import (build_logger, server_error_msg,
    pretty_print_semaphore)
from mplug_owl2.model.builder import load_pretrained_model
from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):

    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
        self.is_multimodal = True

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError("Number of images does not match number of <|image|> tokens in prompt")

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
            else:
                images = None
            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


app = FastAPI()


def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
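
For reference, a minimal client sketch for the streaming protocol this (now deleted) worker exposed: /worker_generate_stream returns NUL-delimited JSON chunks, and each chunk carries the full text generated so far rather than a delta. The worker address, prompt, and stop string below are illustrative assumptions, not values taken from the file.

import json

import requests

WORKER_ADDR = "http://localhost:21002"  # assumption: the worker's default address

params = {
    "prompt": "USER: Hello! ASSISTANT:",  # hypothetical prompt template
    "temperature": 0.2,
    "top_p": 1.0,
    "max_new_tokens": 256,
    "stop": "</s>",  # hypothetical stop string; must match the model's template
}

response = requests.post(WORKER_ADDR + "/worker_generate_stream",
                         json=params, stream=True)
output = ""
# Each NUL-terminated chunk is a JSON object holding the cumulative text so far.
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            output = data["text"]
print(output)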
mplug_docowl/serve/register_workers.py
DELETED
@@ -1,26 +0,0 @@
"""
Manually register workers.

Usage:
python3 -m mplug_docowl.serve.register_workers --controller-address http://localhost:21001 --worker-name http://localhost:21002
"""

import argparse

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--controller-address", type=str)
    parser.add_argument("--worker-name", type=str)
    parser.add_argument("--check-heart-beat", action="store_true")
    args = parser.parse_args()

    url = args.controller_address + "/register_worker"
    data = {
        "worker_name": args.worker_name,
        "check_heart_beat": args.check_heart_beat,
        "worker_status": None,
    }
    r = requests.post(url, json=data)
    assert r.status_code == 200
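
For reference, the registration this script performs, and the heartbeat the controller then expects (per the worker's send_heart_beat above), reduce to two plain HTTP posts. A minimal sketch, assuming the repo's default controller and worker addresses rather than any values hard-coded here:

import requests

CONTROLLER = "http://localhost:21001"  # assumed default controller address
WORKER = "http://localhost:21002"      # assumed default worker address

# Manual registration, equivalent to:
#   python3 -m mplug_docowl.serve.register_workers \
#       --controller-address http://localhost:21001 \
#       --worker-name http://localhost:21002 --check-heart-beat
r = requests.post(CONTROLLER + "/register_worker", json={
    "worker_name": WORKER,
    "check_heart_beat": True,
    "worker_status": None,
})
assert r.status_code == 200

# With check_heart_beat enabled, the controller expects periodic posts like
# the worker's send_heart_beat() above; "exist" == False means re-register.
r = requests.post(CONTROLLER + "/receive_heart_beat", json={
    "worker_name": WORKER,
    "queue_length": 0,
})
print(r.json()["exist"])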