AnwenHu committed on
Commit a44bd40
1 Parent(s): 7914e1b

Delete mplug_docowl/serve

mplug_docowl/serve/__init__.py DELETED
File without changes
mplug_docowl/serve/cli.py DELETED
@@ -1,120 +0,0 @@
- import argparse
- import torch
-
- from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
- from mplug_owl2.conversation import conv_templates, SeparatorStyle
- from mplug_owl2.model.builder import load_pretrained_model
- from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
-
- from PIL import Image
-
- import requests
- from io import BytesIO
- from transformers import TextStreamer
-
-
- def disable_torch_init():
-     """
-     Disable the redundant torch default initialization to accelerate model creation.
-     """
-     import torch
-     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
-     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
-
-
- def load_image(image_file):
-     if image_file.startswith('http://') or image_file.startswith('https://'):
-         response = requests.get(image_file)
-         image = Image.open(BytesIO(response.content)).convert('RGB')
-     else:
-         image = Image.open(image_file).convert('RGB')
-     return image
-
-
- def main(args):
-     # Model
-     disable_torch_init()
-
-     model_name = get_model_name_from_path(args.model_path)
-     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
-
-     conv_mode = "mplug_owl2"
-
-     if args.conv_mode is not None and conv_mode != args.conv_mode:
-         print('[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}; using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
-     else:
-         args.conv_mode = conv_mode
-
-     conv = conv_templates[args.conv_mode].copy()
-     roles = conv.roles
-
-     image = load_image(args.image_file)
-     # Similar operation in model_worker.py
-     image_tensor = process_images([image], image_processor, args)
-     if type(image_tensor) is list:
-         image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
-     else:
-         image_tensor = image_tensor.to(model.device, dtype=torch.float16)
-
-     while True:
-         try:
-             inp = input(f"{roles[0]}: ")
-         except EOFError:
-             inp = ""
-         if not inp:
-             print("exit...")
-             break
-
-         print(f"{roles[1]}: ", end="")
-
-         if image is not None:
-             # first message
-             inp = DEFAULT_IMAGE_TOKEN + inp
-             conv.append_message(conv.roles[0], inp)
-             image = None
-         else:
-             # later messages
-             conv.append_message(conv.roles[0], inp)
-         conv.append_message(conv.roles[1], None)
-         prompt = conv.get_prompt()
-
-         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-         stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
-         keywords = [stop_str]
-         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-         with torch.inference_mode():
-             output_ids = model.generate(
-                 input_ids,
-                 images=image_tensor,
-                 do_sample=True,
-                 temperature=args.temperature,
-                 max_new_tokens=args.max_new_tokens,
-                 streamer=streamer,
-                 use_cache=True,
-                 stopping_criteria=[stopping_criteria])
-
-         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-         conv.messages[-1][-1] = outputs
-
-         if args.debug:
-             print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
-     parser.add_argument("--model-base", type=str, default=None)
-     parser.add_argument("--image-file", type=str, required=True)
-     parser.add_argument("--device", type=str, default="cuda")
-     parser.add_argument("--conv-mode", type=str, default=None)
-     parser.add_argument("--temperature", type=float, default=0.2)
-     parser.add_argument("--max-new-tokens", type=int, default=512)
-     parser.add_argument("--load-8bit", action="store_true")
-     parser.add_argument("--load-4bit", action="store_true")
-     parser.add_argument("--debug", action="store_true")
-     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
-     args = parser.parse_args()
-     main(args)
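
For reference, the interactive loop above collapses to a short single-turn script. This is a sketch under assumptions: the same mplug_owl2 helpers imported at the top of the deleted file, plus a placeholder checkpoint path, image, and question that are not fixed by the commit itself.

# Single-turn sketch of the deleted CLI flow (paths and question are placeholders).
from types import SimpleNamespace

import torch
from PIL import Image
from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_owl2.conversation import conv_templates
from mplug_owl2.mm_utils import (get_model_name_from_path, process_images,
                                 tokenizer_image_token)
from mplug_owl2.model.builder import load_pretrained_model

model_path = "checkpoints/mplug-owl2"  # hypothetical checkpoint location
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), False, False, device="cuda")

# process_images reads image_aspect_ratio from its third argument; the CLI
# passed its argparse namespace, so a minimal stand-in suffices here.
# (The CLI above also handles the case where a list of tensors is returned.)
image = Image.open("demo.jpg").convert("RGB")  # placeholder image
image_tensor = process_images([image], image_processor,
                              SimpleNamespace(image_aspect_ratio="pad"))
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

conv = conv_templates["mplug_owl2"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "What is shown in this image?")
conv.append_message(conv.roles[1], None)
input_ids = tokenizer_image_token(conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(input_ids, images=image_tensor,
                                do_sample=False, max_new_tokens=512, use_cache=True)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip())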
mplug_docowl/serve/controller.py DELETED
@@ -1,298 +0,0 @@
- """
- A controller manages distributed workers.
- It sends worker addresses to clients.
- """
- import argparse
- import asyncio
- import dataclasses
- from enum import Enum, auto
- import json
- import logging
- import time
- from typing import List, Union
- import threading
-
- from fastapi import FastAPI, Request
- from fastapi.responses import StreamingResponse
- import numpy as np
- import requests
- import uvicorn
-
- from mplug_owl2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
- from mplug_owl2.utils import build_logger, server_error_msg
-
-
- logger = build_logger("controller", "controller.log")
-
-
- class DispatchMethod(Enum):
-     LOTTERY = auto()
-     SHORTEST_QUEUE = auto()
-
-     @classmethod
-     def from_str(cls, name):
-         if name == "lottery":
-             return cls.LOTTERY
-         elif name == "shortest_queue":
-             return cls.SHORTEST_QUEUE
-         else:
-             raise ValueError(f"Invalid dispatch method: {name}")
-
-
- @dataclasses.dataclass
- class WorkerInfo:
-     model_names: List[str]
-     speed: int
-     queue_length: int
-     check_heart_beat: bool
-     last_heart_beat: float
-
-
- def heart_beat_controller(controller):
-     while True:
-         time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
-         controller.remove_stale_workers_by_expiration()
-
-
- class Controller:
-     def __init__(self, dispatch_method: str):
-         # Dict[str -> WorkerInfo]
-         self.worker_info = {}
-         self.dispatch_method = DispatchMethod.from_str(dispatch_method)
-
-         self.heart_beat_thread = threading.Thread(
-             target=heart_beat_controller, args=(self,))
-         self.heart_beat_thread.start()
-
-         logger.info("Init controller")
-
-     def register_worker(self, worker_name: str, check_heart_beat: bool,
-                         worker_status: dict):
-         if worker_name not in self.worker_info:
-             logger.info(f"Register a new worker: {worker_name}")
-         else:
-             logger.info(f"Register an existing worker: {worker_name}")
-
-         if not worker_status:
-             worker_status = self.get_worker_status(worker_name)
-         if not worker_status:
-             return False
-
-         self.worker_info[worker_name] = WorkerInfo(
-             worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
-             check_heart_beat, time.time())
-
-         logger.info(f"Register done: {worker_name}, {worker_status}")
-         return True
-
-     def get_worker_status(self, worker_name: str):
-         try:
-             r = requests.post(worker_name + "/worker_get_status", timeout=5)
-         except requests.exceptions.RequestException as e:
-             logger.error(f"Get status fails: {worker_name}, {e}")
-             return None
-
-         if r.status_code != 200:
-             logger.error(f"Get status fails: {worker_name}, {r}")
-             return None
-
-         return r.json()
-
-     def remove_worker(self, worker_name: str):
-         del self.worker_info[worker_name]
-
-     def refresh_all_workers(self):
-         old_info = dict(self.worker_info)
-         self.worker_info = {}
-
-         for w_name, w_info in old_info.items():
-             if not self.register_worker(w_name, w_info.check_heart_beat, None):
-                 logger.info(f"Remove stale worker: {w_name}")
-
-     def list_models(self):
-         model_names = set()
-
-         for w_name, w_info in self.worker_info.items():
-             model_names.update(w_info.model_names)
-
-         return list(model_names)
-
-     def get_worker_address(self, model_name: str):
-         if self.dispatch_method == DispatchMethod.LOTTERY:
-             worker_names = []
-             worker_speeds = []
-             for w_name, w_info in self.worker_info.items():
-                 if model_name in w_info.model_names:
-                     worker_names.append(w_name)
-                     worker_speeds.append(w_info.speed)
-             worker_speeds = np.array(worker_speeds, dtype=np.float32)
-             norm = np.sum(worker_speeds)
-             if norm < 1e-4:
-                 return ""
-             worker_speeds = worker_speeds / norm
-             if True:  # Directly return address
-                 pt = np.random.choice(np.arange(len(worker_names)),
-                                       p=worker_speeds)
-                 worker_name = worker_names[pt]
-                 return worker_name
-
-             # Check status before returning
-             while True:
-                 pt = np.random.choice(np.arange(len(worker_names)),
-                                       p=worker_speeds)
-                 worker_name = worker_names[pt]
-
-                 if self.get_worker_status(worker_name):
-                     break
-                 else:
-                     self.remove_worker(worker_name)
-                     worker_speeds[pt] = 0
-                     norm = np.sum(worker_speeds)
-                     if norm < 1e-4:
-                         return ""
-                     worker_speeds = worker_speeds / norm
-                     continue
-             return worker_name
-         elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
-             worker_names = []
-             worker_qlen = []
-             for w_name, w_info in self.worker_info.items():
-                 if model_name in w_info.model_names:
-                     worker_names.append(w_name)
-                     worker_qlen.append(w_info.queue_length / w_info.speed)
-             if len(worker_names) == 0:
-                 return ""
-             min_index = np.argmin(worker_qlen)
-             w_name = worker_names[min_index]
-             self.worker_info[w_name].queue_length += 1
-             logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
-             return w_name
-         else:
-             raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
-
-     def receive_heart_beat(self, worker_name: str, queue_length: int):
-         if worker_name not in self.worker_info:
-             logger.info(f"Receive unknown heart beat. {worker_name}")
-             return False
-
-         self.worker_info[worker_name].queue_length = queue_length
-         self.worker_info[worker_name].last_heart_beat = time.time()
-         logger.info(f"Receive heart beat. {worker_name}")
-         return True
-
-     def remove_stale_workers_by_expiration(self):
-         expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
-         to_delete = []
-         for worker_name, w_info in self.worker_info.items():
-             if w_info.check_heart_beat and w_info.last_heart_beat < expire:
-                 to_delete.append(worker_name)
-
-         for worker_name in to_delete:
-             self.remove_worker(worker_name)
-
-     def worker_api_generate_stream(self, params):
-         worker_addr = self.get_worker_address(params["model"])
-         if not worker_addr:
-             logger.info(f"no worker: {params['model']}")
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 2,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-
-         try:
-             response = requests.post(worker_addr + "/worker_generate_stream",
-                                      json=params, stream=True, timeout=5)
-             for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-                 if chunk:
-                     yield chunk + b"\0"
-         except requests.exceptions.RequestException as e:
-             logger.info(f"worker timeout: {worker_addr}")
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 3,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-
-     # Let the controller act as a worker to achieve hierarchical
-     # management. This can be used to connect isolated sub networks.
-     def worker_api_get_status(self):
-         model_names = set()
-         speed = 0
-         queue_length = 0
-
-         for w_name in self.worker_info:
-             worker_status = self.get_worker_status(w_name)
-             if worker_status is not None:
-                 model_names.update(worker_status["model_names"])
-                 speed += worker_status["speed"]
-                 queue_length += worker_status["queue_length"]
-
-         return {
-             "model_names": list(model_names),
-             "speed": speed,
-             "queue_length": queue_length,
-         }
-
-
- app = FastAPI()
-
-
- @app.post("/register_worker")
- async def register_worker(request: Request):
-     data = await request.json()
-     controller.register_worker(
-         data["worker_name"], data["check_heart_beat"],
-         data.get("worker_status", None))
-
-
- @app.post("/refresh_all_workers")
- async def refresh_all_workers():
-     models = controller.refresh_all_workers()
-
-
- @app.post("/list_models")
- async def list_models():
-     models = controller.list_models()
-     return {"models": models}
-
-
- @app.post("/get_worker_address")
- async def get_worker_address(request: Request):
-     data = await request.json()
-     addr = controller.get_worker_address(data["model"])
-     return {"address": addr}
-
-
- @app.post("/receive_heart_beat")
- async def receive_heart_beat(request: Request):
-     data = await request.json()
-     exist = controller.receive_heart_beat(
-         data["worker_name"], data["queue_length"])
-     return {"exist": exist}
-
-
- @app.post("/worker_generate_stream")
- async def worker_api_generate_stream(request: Request):
-     params = await request.json()
-     generator = controller.worker_api_generate_stream(params)
-     return StreamingResponse(generator)
-
-
- @app.post("/worker_get_status")
- async def worker_api_get_status(request: Request):
-     return controller.worker_api_get_status()
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--host", type=str, default="localhost")
-     parser.add_argument("--port", type=int, default=21001)
-     parser.add_argument("--dispatch-method", type=str, choices=[
-         "lottery", "shortest_queue"], default="shortest_queue")
-     args = parser.parse_args()
-     logger.info(f"args: {args}")
-
-     controller = Controller(args.dispatch_method)
-     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
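
The HTTP surface above is small enough to exercise directly. A minimal client sketch, assuming a controller running on the default localhost:21001 with at least one registered worker; endpoint names and payload keys come from the routes defined above:

import requests

controller = "http://localhost:21001"  # argparse defaults above

# /list_models aggregates model_names across registered workers.
models = requests.post(controller + "/list_models").json()["models"]

# /get_worker_address dispatches per --dispatch-method (shortest_queue by
# default) and returns "" when no worker serves the requested model.
addr = requests.post(controller + "/get_worker_address",
                     json={"model": models[0]}).json()["address"]
print(models, addr)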
mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg DELETED
Binary file (18.9 kB)
 
mplug_docowl/serve/examples/extreme_ironing.jpg DELETED
Binary file (62.6 kB)
 
mplug_docowl/serve/gradio_web_server.py DELETED
@@ -1,460 +0,0 @@
- import argparse
- import datetime
- import json
- import os
- import time
-
- import gradio as gr
- import requests
-
- from mplug_owl2.conversation import (default_conversation, conv_templates,
-                                      SeparatorStyle)
- from mplug_owl2.constants import LOGDIR
- from mplug_owl2.utils import (build_logger, server_error_msg,
-                               violates_moderation, moderation_msg)
- import hashlib
-
-
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
-
- headers = {"User-Agent": "mPLUG-Owl2 Client"}
-
- no_change_btn = gr.Button.update()
- enable_btn = gr.Button.update(interactive=True)
- disable_btn = gr.Button.update(interactive=False)
-
- priority = {
-     "vicuna-13b": "aaaaaaa",
-     "koala-13b": "aaaaaab",
- }
-
-
- def get_conv_log_filename():
-     t = datetime.datetime.now()
-     name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
-     return name
-
-
- def get_model_list():
-     ret = requests.post(args.controller_url + "/refresh_all_workers")
-     assert ret.status_code == 200
-     ret = requests.post(args.controller_url + "/list_models")
-     models = ret.json()["models"]
-     models.sort(key=lambda x: priority.get(x, x))
-     logger.info(f"Models: {models}")
-     return models
-
-
- get_window_url_params = """
- function() {
-     const params = new URLSearchParams(window.location.search);
-     url_params = Object.fromEntries(params);
-     console.log(url_params);
-     return url_params;
- }
- """
-
-
- def load_demo(url_params, request: gr.Request):
-     logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
-
-     dropdown_update = gr.Dropdown.update(visible=True)
-     if "model" in url_params:
-         model = url_params["model"]
-         if model in models:
-             dropdown_update = gr.Dropdown.update(
-                 value=model, visible=True)
-
-     state = default_conversation.copy()
-     return state, dropdown_update
-
-
- def load_demo_refresh_model_list(request: gr.Request):
-     logger.info(f"load_demo. ip: {request.client.host}")
-     models = get_model_list()
-     state = default_conversation.copy()
-     dropdown_update = gr.Dropdown.update(
-         choices=models,
-         value=models[0] if len(models) > 0 else ""
-     )
-     return state, dropdown_update
-
-
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
-     with open(get_conv_log_filename(), "a") as fout:
-         data = {
-             "tstamp": round(time.time(), 4),
-             "type": vote_type,
-             "model": model_selector,
-             "state": state.dict(),
-             "ip": request.client.host,
-         }
-         fout.write(json.dumps(data) + "\n")
-
-
- def upvote_last_response(state, model_selector, request: gr.Request):
-     logger.info(f"upvote. ip: {request.client.host}")
-     vote_last_response(state, "upvote", model_selector, request)
-     return ("",) + (disable_btn,) * 3
-
-
- def downvote_last_response(state, model_selector, request: gr.Request):
-     logger.info(f"downvote. ip: {request.client.host}")
-     vote_last_response(state, "downvote", model_selector, request)
-     return ("",) + (disable_btn,) * 3
-
-
- def flag_last_response(state, model_selector, request: gr.Request):
-     logger.info(f"flag. ip: {request.client.host}")
-     vote_last_response(state, "flag", model_selector, request)
-     return ("",) + (disable_btn,) * 3
-
-
- def regenerate(state, image_process_mode, request: gr.Request):
-     logger.info(f"regenerate. ip: {request.client.host}")
-     state.messages[-1][-1] = None
-     prev_human_msg = state.messages[-2]
-     if type(prev_human_msg[1]) in (tuple, list):
-         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-     state.skip_next = False
-     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
- def clear_history(request: gr.Request):
-     logger.info(f"clear_history. ip: {request.client.host}")
-     state = default_conversation.copy()
-     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
- def add_text(state, text, image, image_process_mode, request: gr.Request):
-     logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
-     if len(text) <= 0 and image is None:
-         state.skip_next = True
-         return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
-     if args.moderate:
-         flagged = violates_moderation(text)
-         if flagged:
-             state.skip_next = True
-             return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
-                 no_change_btn,) * 5
-
-     text = text[:1536]  # Hard cut-off
-     if image is not None:
-         text = text[:1200]  # Hard cut-off for images
-         if '<|image|>' not in text:
-             # text = text + '<|image|>'
-             text = '<|image|>' + text
-         text = (text, image, image_process_mode)
-         if len(state.get_images(return_pil=True)) > 0:
-             state = default_conversation.copy()
-     state.append_message(state.roles[0], text)
-     state.append_message(state.roles[1], None)
-     state.skip_next = False
-     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
-
-
- def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
-     logger.info(f"http_bot. ip: {request.client.host}")
-     start_tstamp = time.time()
-     model_name = model_selector
-
-     if state.skip_next:
-         # This generate call is skipped due to invalid inputs
-         yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
-         return
-
-     if len(state.messages) == state.offset + 2:
-         # First round of conversation
-         template_name = "mplug_owl2"
-         new_state = conv_templates[template_name].copy()
-         new_state.append_message(new_state.roles[0], state.messages[-2][1])
-         new_state.append_message(new_state.roles[1], None)
-         state = new_state
-
-     # Query worker address
-     controller_url = args.controller_url
-     ret = requests.post(controller_url + "/get_worker_address",
-                         json={"model": model_name})
-     worker_addr = ret.json()["address"]
-     logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
-
-     # No available worker
-     if worker_addr == "":
-         state.messages[-1][-1] = server_error_msg
-         yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-         return
-
-     # Construct prompt
-     prompt = state.get_prompt()
-
-     all_images = state.get_images(return_pil=True)
-     all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
-     for image, hash in zip(all_images, all_image_hash):
-         t = datetime.datetime.now()
-         filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
-         if not os.path.isfile(filename):
-             os.makedirs(os.path.dirname(filename), exist_ok=True)
-             image.save(filename)
-
-     # Make requests
-     pload = {
-         "model": model_name,
-         "prompt": prompt,
-         "temperature": float(temperature),
-         "top_p": float(top_p),
-         "max_new_tokens": min(int(max_new_tokens), 1536),
-         "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
-         "images": f'List of {len(state.get_images())} images: {all_image_hash}',
-     }
-     logger.info(f"==== request ====\n{pload}")
-
-     pload['images'] = state.get_images()
-
-     state.messages[-1][-1] = "▌"
-     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-
-     try:
-         # Stream output
-         response = requests.post(worker_addr + "/worker_generate_stream",
-                                  headers=headers, json=pload, stream=True, timeout=10)
-         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-             if chunk:
-                 data = json.loads(chunk.decode())
-                 if data["error_code"] == 0:
-                     output = data["text"][len(prompt):].strip()
-                     state.messages[-1][-1] = output + "▌"
-                     yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
-                 else:
-                     output = data["text"] + f" (error_code: {data['error_code']})"
-                     state.messages[-1][-1] = output
-                     yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-                     return
-                 time.sleep(0.03)
-     except requests.exceptions.RequestException as e:
-         state.messages[-1][-1] = server_error_msg
-         yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
-         return
-
-     state.messages[-1][-1] = state.messages[-1][-1][:-1]
-     yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
-
-     finish_tstamp = time.time()
-     logger.info(f"{output}")
-
-     with open(get_conv_log_filename(), "a") as fout:
-         data = {
-             "tstamp": round(finish_tstamp, 4),
-             "type": "chat",
-             "model": model_name,
-             "start": round(start_tstamp, 4),
-             "finish": round(finish_tstamp, 4),  # was round(start_tstamp, 4)
-             "state": state.dict(),
-             "images": all_image_hash,
-             "ip": request.client.host,
-         }
-         fout.write(json.dumps(data) + "\n")
-
-
- title_markdown = ("""
- <h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png", alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>
-
- <h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>
-
- <h5 align="center"> If you like our project, please give us a star ✨ on GitHub for the latest updates. </h5>
-
- <div align="center">
-     <div style="display:flex; gap: 0.25rem;" align="center">
-         <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
-         <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
-         <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
-     </div>
- </div>
-
- """)
-
-
- tos_markdown = ("""
- ### Terms of use
- By using this service, users are required to agree to the following terms:
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
- """)
-
-
- learn_more_markdown = ("""
- ### License
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
- """)
-
- block_css = """
-
- #buttons button {
-     min-width: min(120px,100%);
- }
-
- """
-
-
- def build_demo(embed_mode):
-     textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
-     with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
-         state = gr.State()
-
-         if not embed_mode:
-             gr.Markdown(title_markdown)
-
-         with gr.Row():
-             with gr.Column(scale=3):
-                 with gr.Row(elem_id="model_selector_row"):
-                     model_selector = gr.Dropdown(
-                         choices=models,
-                         value=models[0] if len(models) > 0 else "",
-                         interactive=True,
-                         show_label=False,
-                         container=False)
-
-                 imagebox = gr.Image(type="pil")
-                 image_process_mode = gr.Radio(
-                     ["Crop", "Resize", "Pad", "Default"],
-                     value="Default",
-                     label="Preprocess for non-square image", visible=False)
-
-                 cur_dir = os.path.dirname(os.path.abspath(__file__))
-                 gr.Examples(examples=[
-                     [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
-                     [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
-                 ], inputs=[imagebox, textbox])
-
-                 with gr.Accordion("Parameters", open=True) as parameter_row:
-                     temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
-                     top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
-                     max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
-
-             with gr.Column(scale=8):
-                 chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
-                 with gr.Row():
-                     with gr.Column(scale=8):
-                         textbox.render()
-                     with gr.Column(scale=1, min_width=50):
-                         submit_btn = gr.Button(value="Send", variant="primary")
-                 with gr.Row(elem_id="buttons") as button_row:
-                     upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
-                     downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
-                     flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
-                     # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-                     regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
-                     clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
-
-         if not embed_mode:
-             gr.Markdown(tos_markdown)
-             gr.Markdown(learn_more_markdown)
-         url_params = gr.JSON(visible=False)
-
-         # Register listeners
-         btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-         upvote_btn.click(
-             upvote_last_response,
-             [state, model_selector],
-             [textbox, upvote_btn, downvote_btn, flag_btn],
-             queue=False
-         )
-         downvote_btn.click(
-             downvote_last_response,
-             [state, model_selector],
-             [textbox, upvote_btn, downvote_btn, flag_btn],
-             queue=False
-         )
-         flag_btn.click(
-             flag_last_response,
-             [state, model_selector],
-             [textbox, upvote_btn, downvote_btn, flag_btn],
-             queue=False
-         )
-
-         regenerate_btn.click(
-             regenerate,
-             [state, image_process_mode],
-             [state, chatbot, textbox, imagebox] + btn_list,
-             queue=False
-         ).then(
-             http_bot,
-             [state, model_selector, temperature, top_p, max_output_tokens],
-             [state, chatbot] + btn_list
-         )
-
-         clear_btn.click(
-             clear_history,
-             None,
-             [state, chatbot, textbox, imagebox] + btn_list,
-             queue=False
-         )
-
-         textbox.submit(
-             add_text,
-             [state, textbox, imagebox, image_process_mode],
-             [state, chatbot, textbox, imagebox] + btn_list,
-             queue=False
-         ).then(
-             http_bot,
-             [state, model_selector, temperature, top_p, max_output_tokens],
-             [state, chatbot] + btn_list
-         )
-
-         submit_btn.click(
-             add_text,
-             [state, textbox, imagebox, image_process_mode],
-             [state, chatbot, textbox, imagebox] + btn_list,
-             queue=False
-         ).then(
-             http_bot,
-             [state, model_selector, temperature, top_p, max_output_tokens],
-             [state, chatbot] + btn_list
-         )
-
-         if args.model_list_mode == "once":
-             demo.load(
-                 load_demo,
-                 [url_params],
-                 [state, model_selector],
-                 _js=get_window_url_params,
-                 queue=False
-             )
-         elif args.model_list_mode == "reload":
-             demo.load(
-                 load_demo_refresh_model_list,
-                 None,
-                 [state, model_selector],
-                 queue=False
-             )
-         else:
-             raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
-
-     return demo
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--host", type=str, default="0.0.0.0")
-     parser.add_argument("--port", type=int)
-     parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
-     parser.add_argument("--concurrency-count", type=int, default=10)
-     parser.add_argument("--model-list-mode", type=str, default="once",
-                         choices=["once", "reload"])
-     parser.add_argument("--share", action="store_true")
-     parser.add_argument("--moderate", action="store_true")
-     parser.add_argument("--embed", action="store_true")
-     args = parser.parse_args()
-     logger.info(f"args: {args}")
-
-     models = get_model_list()
-
-     logger.info(args)
-     demo = build_demo(args.embed)
-     demo.queue(
-         concurrency_count=args.concurrency_count,
-         api_open=False
-     ).launch(
-         server_name=args.host,
-         server_port=args.port,
-         share=args.share  # honor the --share flag defined above (was hard-coded to False)
-     )
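
vote_last_response and http_bot above append one JSON object per line to a shared per-day log file. A sketch for replaying such a log, assuming the same LOGDIR constant; the date in the filename is a placeholder:

import json
import os

from mplug_owl2.constants import LOGDIR

log_path = os.path.join(LOGDIR, "2024-01-01-conv.json")  # placeholder date
with open(log_path) as f:
    for line in f:
        record = json.loads(line)
        # "type" is "upvote"/"downvote"/"flag" from the vote handlers,
        # or "chat" from http_bot.
        print(record["tstamp"], record["type"], record.get("model"))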
mplug_docowl/serve/model_worker.py DELETED
@@ -1,342 +0,0 @@
- """
- A model worker executes the model.
- """
- import argparse
- import asyncio
- import json
- import time
- import threading
- import uuid
-
- from fastapi import FastAPI, Request, BackgroundTasks
- from fastapi.responses import StreamingResponse
- import requests
- import torch
- import uvicorn
- from functools import partial
-
- from mplug_docowl.utils import (build_logger, server_error_msg,
-                                 pretty_print_semaphore)
-
- from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, WORKER_HEART_BEAT_INTERVAL
- from mplug_docowl.conversation import conv_templates, SeparatorStyle
- from mplug_docowl.model.builder import load_pretrained_model
- from mplug_docowl.mm_utils import load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
- from mplug_docowl.processor import DocProcessor
-
- from transformers import TextStreamer, TextIteratorStreamer  # TextStreamer is used by DocOwlInfer below
- from threading import Thread
-
-
- GB = 1 << 30
-
- worker_id = str(uuid.uuid4())[:6]
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
- global_counter = 0
-
- model_semaphore = None
-
-
- def heart_beat_worker(controller):
-     while True:
-         time.sleep(WORKER_HEART_BEAT_INTERVAL)
-         controller.send_heart_beat()
-
-
- class DocOwlInfer():
-     def __init__(self, ckpt_path, anchors='grid_9', add_global_img=True, load_8bit=False, load_4bit=False):
-         model_name = get_model_name_from_path(ckpt_path)
-         logger.info(f"model_name: {model_name}")  # was ic(model_name); icecream was never imported
-         self.tokenizer, self.model, _, _ = load_pretrained_model(ckpt_path, None, model_name, load_8bit=load_8bit, load_4bit=load_4bit, device="cuda")
-         self.doc_image_processor = DocProcessor(image_size=448, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
-         self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-     def inference(self, image, query):
-         image_tensor, patch_positions, text = self.doc_image_processor(images=image, query='<|image|>' + query)
-         image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
-         patch_positions = patch_positions.to(self.model.device)
-
-         conv = conv_templates["mplug_owl2"].copy()
-         roles = conv.roles  # ("USER", "ASSISTANT")
-
-         conv.append_message(conv.roles[0], text)
-         conv.append_message(conv.roles[1], None)
-         prompt = conv.get_prompt()
-
-         input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
-
-         stop_str = conv.sep2
-         keywords = [stop_str]
-         stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
-
-         with torch.inference_mode():
-             output_ids = self.model.generate(
-                 input_ids,
-                 images=image_tensor,
-                 patch_positions=patch_positions,
-                 do_sample=False,
-                 temperature=1.0,
-                 max_new_tokens=512,
-                 streamer=self.streamer,
-                 use_cache=True,
-                 stopping_criteria=[stopping_criteria])
-
-         outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-
-         return outputs.replace('</s>', '')
-
-
- # TODO: adapt for docowl infer
- class ModelWorker:
-     def __init__(self, controller_addr, worker_addr,
-                  worker_id, no_register,
-                  model_path, model_base, model_name,
-                  resolution, anchors, add_global_img,
-                  load_8bit, load_4bit, device):
-         self.controller_addr = controller_addr
-         self.worker_addr = worker_addr
-         self.worker_id = worker_id
-         if model_path.endswith("/"):
-             model_path = model_path[:-1]
-
-         self.model_name = get_model_name_from_path(model_path)  # was ckpt_path, which is undefined here
-
-         self.device = device
-         logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
-
-         self.tokenizer, self.model, _, self.context_len = load_pretrained_model(
-             model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
-
-         self.resolution = resolution
-         self.token_num_each_img = (self.resolution / 14) * (self.resolution / 14) / self.model.get_model().vison2text.conv_patch
-         self.doc_image_processor = DocProcessor(image_size=resolution, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
-
-         self.is_multimodal = True
-
-         if not no_register:
-             self.register_to_controller()
-             self.heart_beat_thread = threading.Thread(
-                 target=heart_beat_worker, args=(self,))
-             self.heart_beat_thread.start()
-
-     def register_to_controller(self):
-         logger.info("Register to controller")
-
-         url = self.controller_addr + "/register_worker"
-         data = {
-             "worker_name": self.worker_addr,
-             "check_heart_beat": True,
-             "worker_status": self.get_status()
-         }
-         r = requests.post(url, json=data)
-         assert r.status_code == 200
-
-     def send_heart_beat(self):
-         logger.info(f"Send heart beat. Models: {[self.model_name]}. "
-                     f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
-                     f"global_counter: {global_counter}")
-
-         url = self.controller_addr + "/receive_heart_beat"
-
-         while True:
-             try:
-                 ret = requests.post(url, json={
-                     "worker_name": self.worker_addr,
-                     "queue_length": self.get_queue_length()}, timeout=5)
-                 exist = ret.json()["exist"]
-                 break
-             except requests.exceptions.RequestException as e:
-                 logger.error(f"heart beat error: {e}")
-             time.sleep(5)
-
-         if not exist:
-             self.register_to_controller()
-
-     def get_queue_length(self):
-         if model_semaphore is None:
-             return 0
-         else:
-             return args.limit_model_concurrency - model_semaphore._value + (len(
-                 model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
-
-     def get_status(self):
-         return {
-             "model_names": [self.model_name],
-             "speed": 1,
-             "queue_length": self.get_queue_length(),
-         }
-
-     @torch.inference_mode()
-     def generate_stream(self, params):
-         tokenizer, model = self.tokenizer, self.model  # no separate CLIP image processor; DocProcessor handles images
-
-         prompt = params["prompt"]
-         ori_prompt = prompt
-         images = params.get("images", None)
-         num_image_tokens = 0
-         if images is not None and len(images) > 0 and self.is_multimodal:
-             if len(images) > 0:
-
-                 """if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                     raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
-
-                 images = [load_image_from_base64(image) for image in images]
-                 images = process_images(images, image_processor, model.config)
-
-                 if type(images) is list:
-                     images = [image.to(self.model.device, dtype=torch.float16) for image in images]
-                 else:
-                     images = images.to(self.model.device, dtype=torch.float16)"""
-
-                 # docowl only supports 1 image, so only keep the last image
-                 image = images[-1]
-                 assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
-
-                 image_tensor, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
-                 image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
-                 patch_positions = patch_positions.to(self.model.device)
-                 images = image_tensor  # hand the processed tensor to generate; the raw params["images"] entry is not a tensor
-
-                 replace_token = DEFAULT_IMAGE_TOKEN
-                 prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-                 num_image_tokens = prompt.count(replace_token) * (self.token_num_each_img + 1)
-             else:
-                 images = None
-                 patch_positions = None
-             image_args = {"images": images, "patch_positions": patch_positions}
-         else:
-             images = None
-             image_args = {}
-
-         temperature = float(params.get("temperature", 1.0))
-         top_p = float(params.get("top_p", 1.0))
-         max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
-         max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
-         stop_str = params.get("stop", None)
-         do_sample = True if temperature > 0.001 else False
-
-         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
-         keywords = [stop_str]
-         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-
-         max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
-         if max_new_tokens < 1:
-             yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
-             return
-
-         thread = Thread(target=model.generate, kwargs=dict(
-             inputs=input_ids,
-             do_sample=do_sample,
-             temperature=temperature,
-             top_p=top_p,
-             max_new_tokens=max_new_tokens,
-             streamer=streamer,
-             stopping_criteria=[stopping_criteria],
-             use_cache=True,
-             **image_args
-         ))
-         thread.start()
-
-         generated_text = ori_prompt
-         for new_text in streamer:
-             generated_text += new_text
-             if generated_text.endswith(stop_str):
-                 generated_text = generated_text[:-len(stop_str)]
-             yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
-
-     def generate_stream_gate(self, params):
-         try:
-             for x in self.generate_stream(params):
-                 yield x
-         except ValueError as e:
-             print("Caught ValueError:", e)
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 1,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-         except torch.cuda.CudaError as e:
-             print("Caught torch.cuda.CudaError:", e)
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 1,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-         except Exception as e:
-             print("Caught Unknown Error", e)
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 1,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-
-
- app = FastAPI()
-
-
- def release_model_semaphore(fn=None):
-     model_semaphore.release()
-     if fn is not None:
-         fn()
-
-
- @app.post("/worker_generate_stream")
- async def generate_stream(request: Request):
-     global model_semaphore, global_counter
-     global_counter += 1
-     params = await request.json()
-
-     if model_semaphore is None:
-         model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
-     await model_semaphore.acquire()
-     worker.send_heart_beat()
-     generator = worker.generate_stream_gate(params)
-     background_tasks = BackgroundTasks()
-     background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
-     return StreamingResponse(generator, background=background_tasks)
-
-
- @app.post("/worker_get_status")
- async def get_status(request: Request):
-     return worker.get_status()
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--host", type=str, default="localhost")
-     parser.add_argument("--port", type=int, default=21002)
-     parser.add_argument("--worker-address", type=str,
-                         default="http://localhost:21002")
-     parser.add_argument("--controller-address", type=str,
-                         default="http://localhost:21001")
-     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
-     parser.add_argument("--model-base", type=str, default=None)
-     parser.add_argument("--model-name", type=str)
-     parser.add_argument("--device", type=str, default="cuda")
-     parser.add_argument("--limit-model-concurrency", type=int, default=5)
-     parser.add_argument("--stream-interval", type=int, default=1)
-     parser.add_argument("--no-register", action="store_true")
-     parser.add_argument("--load-8bit", action="store_true")
-     parser.add_argument("--load-4bit", action="store_true")
-     args = parser.parse_args()
-     logger.info(f"args: {args}")
-
-     worker = ModelWorker(args.controller_address,
-                          args.worker_address,
-                          worker_id,
-                          args.no_register,
-                          args.model_path,
-                          args.model_base,
-                          args.model_name,
-                          # resolution/anchors/add_global_img were missing from the
-                          # original call; these values mirror the DocOwlInfer defaults above
-                          448,
-                          'grid_9',
-                          True,
-                          args.load_8bit,
-                          args.load_4bit,
-                          args.device)
-     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
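
Both the controller relay and the gradio server consume the worker stream the same way: NUL-delimited JSON chunks, each carrying the accumulated text and an error_code. A standalone client sketch against the default worker address; the prompt and image entry are placeholders, and per the code above the docowl worker expects exactly one <|image|> token and uses only images[-1]:

import json
import requests

worker = "http://localhost:21002"  # --worker-address default above
payload = {
    "prompt": "USER: <|image|>What is the title of this document? ASSISTANT:",  # placeholder
    "images": ["demo.jpg"],   # the worker forwards images[-1] to DocProcessor
    "temperature": 0.2,
    "max_new_tokens": 256,
    "stop": "</s>",           # used as the stop keyword by generate_stream
}
resp = requests.post(worker + "/worker_generate_stream",
                     json=payload, stream=True, timeout=10)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        if data["error_code"] != 0:
            raise RuntimeError(data["text"])
        print(data["text"])  # accumulated text so far, prompt included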
mplug_docowl/serve/model_worker_bak.py DELETED
@@ -1,278 +0,0 @@
- """
- A model worker executes the model.
- """
- import argparse
- import asyncio
- import json
- import time
- import threading
- import uuid
-
- from fastapi import FastAPI, Request, BackgroundTasks
- from fastapi.responses import StreamingResponse
- import requests
- import torch
- import uvicorn
- from functools import partial
-
- from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
- from mplug_owl2.utils import (build_logger, server_error_msg,
-                               pretty_print_semaphore)
- from mplug_owl2.model.builder import load_pretrained_model
- from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
- from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
- from transformers import TextIteratorStreamer
- from threading import Thread
-
-
- GB = 1 << 30
-
- worker_id = str(uuid.uuid4())[:6]
- logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
- global_counter = 0
-
- model_semaphore = None
-
-
- def heart_beat_worker(controller):
-     while True:
-         time.sleep(WORKER_HEART_BEAT_INTERVAL)
-         controller.send_heart_beat()
-
-
- class ModelWorker:
-     def __init__(self, controller_addr, worker_addr,
-                  worker_id, no_register,
-                  model_path, model_base, model_name,
-                  load_8bit, load_4bit, device):
-         self.controller_addr = controller_addr
-         self.worker_addr = worker_addr
-         self.worker_id = worker_id
-         if model_path.endswith("/"):
-             model_path = model_path[:-1]
-         if model_name is None:
-             model_paths = model_path.split("/")
-             if model_paths[-1].startswith('checkpoint-'):
-                 self.model_name = model_paths[-2] + "_" + model_paths[-1]
-             else:
-                 self.model_name = model_paths[-1]
-         else:
-             self.model_name = model_name
-
-         self.device = device
-         logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
-         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
-             model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
-         self.is_multimodal = True
-
-         if not no_register:
-             self.register_to_controller()
-             self.heart_beat_thread = threading.Thread(
-                 target=heart_beat_worker, args=(self,))
-             self.heart_beat_thread.start()
-
-     def register_to_controller(self):
-         logger.info("Register to controller")
-
-         url = self.controller_addr + "/register_worker"
-         data = {
-             "worker_name": self.worker_addr,
-             "check_heart_beat": True,
-             "worker_status": self.get_status()
-         }
-         r = requests.post(url, json=data)
-         assert r.status_code == 200
-
-     def send_heart_beat(self):
-         logger.info(f"Send heart beat. Models: {[self.model_name]}. "
-                     f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
-                     f"global_counter: {global_counter}")
-
-         url = self.controller_addr + "/receive_heart_beat"
-
-         while True:
-             try:
-                 ret = requests.post(url, json={
-                     "worker_name": self.worker_addr,
-                     "queue_length": self.get_queue_length()}, timeout=5)
-                 exist = ret.json()["exist"]
-                 break
-             except requests.exceptions.RequestException as e:
-                 logger.error(f"heart beat error: {e}")
-             time.sleep(5)
-
-         if not exist:
-             self.register_to_controller()
-
-     def get_queue_length(self):
-         if model_semaphore is None:
-             return 0
-         else:
-             return args.limit_model_concurrency - model_semaphore._value + (len(
-                 model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
-
-     def get_status(self):
-         return {
-             "model_names": [self.model_name],
-             "speed": 1,
-             "queue_length": self.get_queue_length(),
-         }
-
-     @torch.inference_mode()
-     def generate_stream(self, params):
-         tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
-
-         prompt = params["prompt"]
-         ori_prompt = prompt
-         images = params.get("images", None)
-         num_image_tokens = 0
-         if images is not None and len(images) > 0 and self.is_multimodal:
-             if len(images) > 0:
-                 if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                     raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
-
-                 images = [load_image_from_base64(image) for image in images]
-                 images = process_images(images, image_processor, model.config)
-
-                 if type(images) is list:
-                     images = [image.to(self.model.device, dtype=torch.float16) for image in images]
-                 else:
-                     images = images.to(self.model.device, dtype=torch.float16)
-
-                 replace_token = DEFAULT_IMAGE_TOKEN
-                 prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-
-                 num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
-             else:
-                 images = None
-             image_args = {"images": images}
-         else:
-             images = None
-             image_args = {}
-
-         temperature = float(params.get("temperature", 1.0))
-         top_p = float(params.get("top_p", 1.0))
-         max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
-         max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
-         stop_str = params.get("stop", None)
-         do_sample = True if temperature > 0.001 else False
-
-         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
-         keywords = [stop_str]
-         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-
-         max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
-         if max_new_tokens < 1:
-             yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
-             return
-
-         thread = Thread(target=model.generate, kwargs=dict(
-             inputs=input_ids,
-             do_sample=do_sample,
-             temperature=temperature,
-             top_p=top_p,
-             max_new_tokens=max_new_tokens,
-             streamer=streamer,
-             stopping_criteria=[stopping_criteria],
-             use_cache=True,
-             **image_args
-         ))
-         thread.start()
-
-         generated_text = ori_prompt
-         for new_text in streamer:
-             generated_text += new_text
-             if generated_text.endswith(stop_str):
-                 generated_text = generated_text[:-len(stop_str)]
-             yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
-
-     def generate_stream_gate(self, params):
-         try:
-             for x in self.generate_stream(params):
-                 yield x
-         except ValueError as e:
-             print("Caught ValueError:", e)
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 1,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-         except torch.cuda.CudaError as e:
-             print("Caught torch.cuda.CudaError:", e)
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 1,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-         except Exception as e:
-             print("Caught Unknown Error", e)
-             ret = {
-                 "text": server_error_msg,
-                 "error_code": 1,
-             }
-             yield json.dumps(ret).encode() + b"\0"
-
-
- app = FastAPI()
-
-
- def release_model_semaphore(fn=None):
-     model_semaphore.release()
-     if fn is not None:
-         fn()
-
-
- @app.post("/worker_generate_stream")
- async def generate_stream(request: Request):
-     global model_semaphore, global_counter
-     global_counter += 1
-     params = await request.json()
-
-     if model_semaphore is None:
-         model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
-     await model_semaphore.acquire()
-     worker.send_heart_beat()
-     generator = worker.generate_stream_gate(params)
-     background_tasks = BackgroundTasks()
-     background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
-     return StreamingResponse(generator, background=background_tasks)
-
-
- @app.post("/worker_get_status")
- async def get_status(request: Request):
-     return worker.get_status()
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--host", type=str, default="localhost")
-     parser.add_argument("--port", type=int, default=21002)
-     parser.add_argument("--worker-address", type=str,
-                         default="http://localhost:21002")
-     parser.add_argument("--controller-address", type=str,
-                         default="http://localhost:21001")
-     parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
-     parser.add_argument("--model-base", type=str, default=None)
-     parser.add_argument("--model-name", type=str)
-     parser.add_argument("--device", type=str, default="cuda")
-     parser.add_argument("--limit-model-concurrency", type=int, default=5)
-     parser.add_argument("--stream-interval", type=int, default=1)
-     parser.add_argument("--no-register", action="store_true")
-     parser.add_argument("--load-8bit", action="store_true")
-     parser.add_argument("--load-4bit", action="store_true")
-     args = parser.parse_args()
-     logger.info(f"args: {args}")
-
-     worker = ModelWorker(args.controller_address,
-                          args.worker_address,
-                          worker_id,
-                          args.no_register,
-                          args.model_path,
-                          args.model_base,
-                          args.model_name,
-                          args.load_8bit,
-                          args.load_4bit,
-                          args.device)
-     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
mplug_docowl/serve/register_workers.py DELETED
@@ -1,26 +0,0 @@
- """
- Manually register workers.
-
- Usage:
- python3 -m mplug_docowl.serve.register_workers --controller-address http://localhost:21001 --worker-name http://localhost:21002
- """
-
- import argparse
-
- import requests
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--controller-address", type=str)
-     parser.add_argument("--worker-name", type=str)
-     parser.add_argument("--check-heart-beat", action="store_true")
-     args = parser.parse_args()
-
-     url = args.controller_address + "/register_worker"
-     data = {
-         "worker_name": args.worker_name,
-         "check_heart_beat": args.check_heart_beat,
-         "worker_status": None,
-     }
-     r = requests.post(url, json=data)
-     assert r.status_code == 200