AnwenHu committed
Commit
d87616f
1 Parent(s): beee9d0

Upload 52 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the remaining files.
Files changed (50)
  1. mplug_docowl/__init__.py +2 -0
  2. mplug_docowl/__pycache__/__init__.cpython-310.pyc +0 -0
  3. mplug_docowl/__pycache__/constants.cpython-310.pyc +0 -0
  4. mplug_docowl/__pycache__/conversation.cpython-310.pyc +0 -0
  5. mplug_docowl/__pycache__/mm_utils.cpython-310.pyc +0 -0
  6. mplug_docowl/__pycache__/processor.cpython-310.pyc +0 -0
  7. mplug_docowl/__pycache__/utils.cpython-310.pyc +0 -0
  8. mplug_docowl/constants.py +9 -0
  9. mplug_docowl/conversation.py +301 -0
  10. mplug_docowl/local_serve/__init__.py +0 -0
  11. mplug_docowl/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg +0 -0
  12. mplug_docowl/local_serve/examples/extreme_ironing.jpg +0 -0
  13. mplug_docowl/local_serve/local_web_server.py +392 -0
  14. mplug_docowl/local_serve/model_worker.py +143 -0
  15. mplug_docowl/mm_utils.py +112 -0
  16. mplug_docowl/model/__init__.py +2 -0
  17. mplug_docowl/model/__pycache__/__init__.cpython-310.pyc +0 -0
  18. mplug_docowl/model/__pycache__/builder.cpython-310.pyc +0 -0
  19. mplug_docowl/model/__pycache__/configuration_mplug_docowl.cpython-310.pyc +0 -0
  20. mplug_docowl/model/__pycache__/configuration_mplug_docowl2.cpython-310.pyc +0 -0
  21. mplug_docowl/model/__pycache__/convert_mplug_docowl2_weight_to_hf.cpython-310.pyc +0 -0
  22. mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf.cpython-310.pyc +0 -0
  23. mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf_v2.cpython-310.pyc +0 -0
  24. mplug_docowl/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc +0 -0
  25. mplug_docowl/model/__pycache__/modeling_llama2.cpython-310.pyc +0 -0
  26. mplug_docowl/model/__pycache__/modeling_mplug_docowl.cpython-310.pyc +0 -0
  27. mplug_docowl/model/__pycache__/modeling_mplug_docowl2.cpython-310.pyc +0 -0
  28. mplug_docowl/model/__pycache__/visual_encoder.cpython-310.pyc +0 -0
  29. mplug_docowl/model/builder.py +81 -0
  30. mplug_docowl/model/configuration_mplug_docowl.py +318 -0
  31. mplug_docowl/model/convert_mplug_docowl_weight_to_hf.py +319 -0
  32. mplug_docowl/model/convert_mplug_docowl_weight_to_hf_v2.py +320 -0
  33. mplug_docowl/model/modeling_attn_mask_utils.py +247 -0
  34. mplug_docowl/model/modeling_llama2.py +486 -0
  35. mplug_docowl/model/modeling_mplug_docowl.py +313 -0
  36. mplug_docowl/model/utils.py +20 -0
  37. mplug_docowl/model/visual_encoder.py +499 -0
  38. mplug_docowl/processor.py +219 -0
  39. mplug_docowl/serve/__init__.py +0 -0
  40. mplug_docowl/serve/cli.py +120 -0
  41. mplug_docowl/serve/controller.py +298 -0
  42. mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg +0 -0
  43. mplug_docowl/serve/examples/extreme_ironing.jpg +0 -0
  44. mplug_docowl/serve/gradio_web_server.py +460 -0
  45. mplug_docowl/serve/model_worker.py +342 -0
  46. mplug_docowl/serve/model_worker_bak.py +278 -0
  47. mplug_docowl/serve/register_workers.py +26 -0
  48. mplug_docowl/train/llama_flash_attn_monkey_patch.py +117 -0
  49. mplug_docowl/train/mplug_owl2_trainer.py +243 -0
  50. mplug_docowl/train/train.py +801 -0
mplug_docowl/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model import MPLUGDocOwlLlamaForCausalLM
+ from .processor import DocProcessor
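A minimal usage sketch for these two exports (not part of the commit): the checkpoint path is a placeholder, and DocProcessor's constructor arguments are not shown in this excerpt, so it is instantiated with defaults as an assumption.
# Hypothetical illustration of the package's public surface; path and arguments are placeholders.
from mplug_docowl import MPLUGDocOwlLlamaForCausalLM, DocProcessor

model = MPLUGDocOwlLlamaForCausalLM.from_pretrained("path/to/docowl-checkpoint")  # placeholder path
doc_processor = DocProcessor()  # constructor arguments not shown in this diff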
mplug_docowl/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (271 Bytes)
mplug_docowl/__pycache__/constants.cpython-310.pyc ADDED
Binary file (368 Bytes)
mplug_docowl/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (8.56 kB)
mplug_docowl/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (4.58 kB)
mplug_docowl/__pycache__/processor.cpython-310.pyc ADDED
Binary file (6.68 kB)
mplug_docowl/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.05 kB)
mplug_docowl/constants.py ADDED
@@ -0,0 +1,9 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "./demo_logs"
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<|image|>"
mplug_docowl/conversation.py ADDED
@@ -0,0 +1,301 @@
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+ from mplug_docowl.constants import DEFAULT_IMAGE_TOKEN
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     TWO_NO_SYS = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             # init_msg = init_msg[0].replace("<image>", "").strip()
+             # if 'mmtag' in self.version:
+             #     messages[0] = (init_role, init_msg)
+             #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+             #     messages.insert(1, (self.roles[1], "Received."))
+             # else:
+             #     messages[0] = (init_role, "<image>\n" + init_msg)
+             init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+             messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
+             seps = [self.sep, self.sep2]
+             ret = ""
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     from PIL import Image
+                     msg, image, image_process_mode = msg
+                     if image_process_mode == "Pad":
+                         def expand2square(pil_img, background_color=(122, 116, 104)):
+                             width, height = pil_img.size
+                             if width == height:
+                                 return pil_img
+                             elif width > height:
+                                 result = Image.new(pil_img.mode, (width, width), background_color)
+                                 result.paste(pil_img, (0, (width - height) // 2))
+                                 return result
+                             else:
+                                 result = Image.new(pil_img.mode, (height, height), background_color)
+                                 result.paste(pil_img, ((height - width) // 2, 0))
+                                 return result
+                         image = expand2square(image)
+                     elif image_process_mode in ["Default", "Crop"]:
+                         pass
+                     elif image_process_mode == "Resize":
+                         image = image.resize((336, 336))
+                     else:
+                         raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if longest_edge != max(image.size):
+                         if H > W:
+                             H, W = longest_edge, shortest_edge
+                         else:
+                             H, W = shortest_edge, longest_edge
+                         image = image.resize((W, H))
+                     if return_pil:
+                         images.append(image)
+                     else:
+                         buffered = BytesIO()
+                         image.save(buffered, format="PNG")
+                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                         images.append(img_b64_str)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<|image|>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+          "Renewable energy sources are those that can be replenished naturally in a relatively "
+          "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+          "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+          "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+          "renewable and non-renewable energy sources:\n"
+          "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+          "energy sources are finite and will eventually run out.\n"
+          "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+          "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+          "and other negative effects.\n"
+          "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+          "have lower operational costs than non-renewable sources.\n"
+          "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+          "locations than non-renewable sources.\n"
+          "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+          "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+          "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+          "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_mplug_owl2 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO_NO_SYS,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ # default_conversation = conv_vicuna_v1
+ default_conversation = conv_mplug_owl2
+ conv_templates = {
+     "default": conv_vicuna_v0,
+     "v0": conv_vicuna_v0,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "mplug_owl2": conv_mplug_owl2,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
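A short sketch (not part of the commit) of driving the Conversation API above to build a prompt with the default mplug_owl2 template:
from mplug_docowl.conversation import conv_templates

conv = conv_templates["mplug_owl2"].copy()
conv.append_message(conv.roles[0], "<|image|>What is shown in this document?")
conv.append_message(conv.roles[1], None)  # empty slot for the assistant's reply
prompt = conv.get_prompt()
# With SeparatorStyle.TWO_NO_SYS, sep=" " and sep2="</s>", this yields:
# "USER: <|image|>What is shown in this document? ASSISTANT:"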
mplug_docowl/local_serve/__init__.py ADDED
File without changes
mplug_docowl/local_serve/examples/Rebecca_(1939_poster)_Small.jpeg ADDED
mplug_docowl/local_serve/examples/extreme_ironing.jpg ADDED
mplug_docowl/local_serve/local_web_server.py ADDED
@@ -0,0 +1,392 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ from mplug_owl2.conversation import (default_conversation, conv_templates,
11
+ SeparatorStyle)
12
+ from mplug_owl2.constants import LOGDIR
13
+ from mplug_owl2.utils import (build_logger, server_error_msg,
14
+ violates_moderation, moderation_msg)
15
+ from .model_worker import ModelWorker
16
+ import hashlib
17
+
18
+ logger = build_logger("gradio_web_server_local", "gradio_web_server_local.log")
19
+
20
+ headers = {"User-Agent": "mPLUG-Owl2 Client"}
21
+
22
+ no_change_btn = gr.Button.update()
23
+ enable_btn = gr.Button.update(interactive=True)
24
+ disable_btn = gr.Button.update(interactive=False)
25
+
26
+ def get_conv_log_filename():
27
+ t = datetime.datetime.now()
28
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
29
+ return name
30
+
31
+ get_window_url_params = """
32
+ function() {
33
+ const params = new URLSearchParams(window.location.search);
34
+ url_params = Object.fromEntries(params);
35
+ console.log(url_params);
36
+ return url_params;
37
+ }
38
+ """
39
+
40
+
41
+ def load_demo(url_params, request: gr.Request):
42
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
43
+ state = default_conversation.copy()
44
+ return state
45
+
46
+
47
+ def vote_last_response(state, vote_type, request: gr.Request):
48
+ with open(get_conv_log_filename(), "a") as fout:
49
+ data = {
50
+ "tstamp": round(time.time(), 4),
51
+ "type": vote_type,
52
+ "state": state.dict(),
53
+ "ip": request.client.host,
54
+ }
55
+ fout.write(json.dumps(data) + "\n")
56
+
57
+
58
+ def upvote_last_response(state, request: gr.Request):
59
+ logger.info(f"upvote. ip: {request.client.host}")
60
+ vote_last_response(state, "upvote", request)
61
+ return ("",) + (disable_btn,) * 3
62
+
63
+
64
+ def downvote_last_response(state, request: gr.Request):
65
+ logger.info(f"downvote. ip: {request.client.host}")
66
+ vote_last_response(state, "downvote", request)
67
+ return ("",) + (disable_btn,) * 3
68
+
69
+
70
+ def flag_last_response(state, request: gr.Request):
71
+ logger.info(f"flag. ip: {request.client.host}")
72
+ vote_last_response(state, "flag", request)
73
+ return ("",) + (disable_btn,) * 3
74
+
75
+
76
+ def regenerate(state, image_process_mode, request: gr.Request):
77
+ logger.info(f"regenerate. ip: {request.client.host}")
78
+ state.messages[-1][-1] = None
79
+ prev_human_msg = state.messages[-2]
80
+ if type(prev_human_msg[1]) in (tuple, list):
81
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
82
+ state.skip_next = False
83
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
84
+
85
+
86
+ def clear_history(request: gr.Request):
87
+ logger.info(f"clear_history. ip: {request.client.host}")
88
+ state = default_conversation.copy()
89
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
90
+
91
+
92
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
93
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
94
+ if len(text) <= 0 and image is None:
95
+ state.skip_next = True
96
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
97
+ if args.moderate:
98
+ flagged = violates_moderation(text)
99
+ if flagged:
100
+ state.skip_next = True
101
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
102
+ no_change_btn,) * 5
103
+
104
+ text = text[:3584] # Hard cut-off
105
+ if image is not None:
106
+ text = text[:3500] # Hard cut-off for images
107
+ if '<|image|>' not in text:
108
+ text = '<|image|>' + text
109
+ text = (text, image, image_process_mode)
110
+ if len(state.get_images(return_pil=True)) > 0:
111
+ state = default_conversation.copy()
112
+ state.append_message(state.roles[0], text)
113
+ state.append_message(state.roles[1], None)
114
+ state.skip_next = False
115
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
116
+
117
+
118
+ def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request):
119
+ logger.info(f"http_bot. ip: {request.client.host}")
120
+ start_tstamp = time.time()
121
+
122
+ if state.skip_next:
123
+ # This generate call is skipped due to invalid inputs
124
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
125
+ return
126
+
127
+ if len(state.messages) == state.offset + 2:
128
+ # First round of conversation
129
+ template_name = "mplug_owl2"
130
+ new_state = conv_templates[template_name].copy()
131
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
132
+ new_state.append_message(new_state.roles[1], None)
133
+ state = new_state
134
+
135
+ # Construct prompt
136
+ prompt = state.get_prompt()
137
+
138
+ all_images = state.get_images(return_pil=True)
139
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
140
+ for image, hash in zip(all_images, all_image_hash):
141
+ t = datetime.datetime.now()
142
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
143
+ if not os.path.isfile(filename):
144
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
145
+ image.save(filename)
146
+
147
+ # Make requests
148
+ pload = {
149
+ "prompt": prompt,
150
+ "temperature": float(temperature),
151
+ "top_p": float(top_p),
152
+ "max_new_tokens": min(int(max_new_tokens), 2048),
153
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
154
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
155
+ }
156
+ logger.info(f"==== request ====\n{pload}")
157
+
158
+ pload['images'] = state.get_images()
159
+
160
+ state.messages[-1][-1] = "▌"
161
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
162
+
163
+ try:
164
+ # Stream output
165
+ # response = requests.post(worker_addr + "/worker_generate_stream",
166
+ # headers=headers, json=pload, stream=True, timeout=10)
167
+ # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
168
+ response = model.generate_stream_gate(pload)
169
+ for chunk in response:
170
+ if chunk:
171
+ data = json.loads(chunk.decode())
172
+ if data["error_code"] == 0:
173
+ output = data["text"][len(prompt):].strip()
174
+ state.messages[-1][-1] = output + "▌"
175
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
176
+ else:
177
+ output = data["text"] + f" (error_code: {data['error_code']})"
178
+ state.messages[-1][-1] = output
179
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
180
+ return
181
+ time.sleep(0.03)
182
+ except requests.exceptions.RequestException as e:
183
+ state.messages[-1][-1] = server_error_msg
184
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
185
+ return
186
+
187
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
188
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
189
+
190
+ finish_tstamp = time.time()
191
+ logger.info(f"{output}")
192
+
193
+ with open(get_conv_log_filename(), "a") as fout:
194
+ data = {
195
+ "tstamp": round(finish_tstamp, 4),
196
+ "type": "chat",
197
+ "start": round(start_tstamp, 4),
198
+ "finish": round(start_tstamp, 4),
199
+ "state": state.dict(),
200
+ "images": all_image_hash,
201
+ "ip": request.client.host,
202
+ }
203
+ fout.write(json.dumps(data) + "\n")
204
+
205
+
206
+ title_markdown = ("""
207
+ <h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png", alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>
208
+
209
+ <h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>
210
+
211
+ <h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h2>
212
+
213
+ <div align="center">
214
+ <div style="display:flex; gap: 0.25rem;" align="center">
215
+ <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
216
+ <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
217
+ <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
218
+ </div>
219
+ </div>
220
+
221
+ """)
222
+
223
+
224
+ tos_markdown = ("""
225
+ ### Terms of use
226
+ By using this service, users are required to agree to the following terms:
227
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
228
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
229
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
230
+ """)
231
+
232
+
233
+ learn_more_markdown = ("""
234
+ ### License
235
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
236
+ """)
237
+
238
+ block_css = """
239
+
240
+ #buttons button {
241
+ min-width: min(120px,100%);
242
+ }
243
+
244
+ """
245
+
246
+ def build_demo(embed_mode):
247
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
248
+ with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
249
+ state = gr.State()
250
+
251
+ if not embed_mode:
252
+ gr.Markdown(title_markdown)
253
+
254
+ with gr.Row():
255
+ with gr.Column(scale=3):
256
+ imagebox = gr.Image(type="pil")
257
+ image_process_mode = gr.Radio(
258
+ ["Crop", "Resize", "Pad", "Default"],
259
+ value="Default",
260
+ label="Preprocess for non-square image", visible=False)
261
+
262
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
263
+ gr.Examples(examples=[
264
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
265
+ [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
266
+ ], inputs=[imagebox, textbox])
267
+
268
+ with gr.Accordion("Parameters", open=True) as parameter_row:
269
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
270
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
271
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
272
+
273
+ with gr.Column(scale=8):
274
+ chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
275
+ with gr.Row():
276
+ with gr.Column(scale=8):
277
+ textbox.render()
278
+ with gr.Column(scale=1, min_width=50):
279
+ submit_btn = gr.Button(value="Send", variant="primary")
280
+ with gr.Row(elem_id="buttons") as button_row:
281
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
282
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
283
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
284
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
285
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
286
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
287
+
288
+ if not embed_mode:
289
+ gr.Markdown(tos_markdown)
290
+ gr.Markdown(learn_more_markdown)
291
+ url_params = gr.JSON(visible=False)
292
+
293
+ # Register listeners
294
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
295
+ upvote_btn.click(
296
+ upvote_last_response,
297
+ state,
298
+ [textbox, upvote_btn, downvote_btn, flag_btn],
299
+ queue=False
300
+ )
301
+ downvote_btn.click(
302
+ downvote_last_response,
303
+ state,
304
+ [textbox, upvote_btn, downvote_btn, flag_btn],
305
+ queue=False
306
+ )
307
+ flag_btn.click(
308
+ flag_last_response,
309
+ state,
310
+ [textbox, upvote_btn, downvote_btn, flag_btn],
311
+ queue=False
312
+ )
313
+
314
+ regenerate_btn.click(
315
+ regenerate,
316
+ [state, image_process_mode],
317
+ [state, chatbot, textbox, imagebox] + btn_list,
318
+ queue=False
319
+ ).then(
320
+ http_bot,
321
+ [state, temperature, top_p, max_output_tokens],
322
+ [state, chatbot] + btn_list
323
+ )
324
+
325
+ clear_btn.click(
326
+ clear_history,
327
+ None,
328
+ [state, chatbot, textbox, imagebox] + btn_list,
329
+ queue=False
330
+ )
331
+
332
+ textbox.submit(
333
+ add_text,
334
+ [state, textbox, imagebox, image_process_mode],
335
+ [state, chatbot, textbox, imagebox] + btn_list,
336
+ queue=False
337
+ ).then(
338
+ http_bot,
339
+ [state, temperature, top_p, max_output_tokens],
340
+ [state, chatbot] + btn_list
341
+ )
342
+
343
+ submit_btn.click(
344
+ add_text,
345
+ [state, textbox, imagebox, image_process_mode],
346
+ [state, chatbot, textbox, imagebox] + btn_list,
347
+ queue=False
348
+ ).then(
349
+ http_bot,
350
+ [state, temperature, top_p, max_output_tokens],
351
+ [state, chatbot] + btn_list
352
+ )
353
+
354
+ demo.load(
355
+ load_demo,
356
+ [url_params],
357
+ state,
358
+ _js=get_window_url_params,
359
+ queue=False
360
+ )
361
+
362
+ return demo
363
+
364
+
365
+ if __name__ == "__main__":
366
+ parser = argparse.ArgumentParser()
367
+ parser.add_argument("--host", type=str, default="0.0.0.0")
368
+ parser.add_argument("--port", type=int)
369
+ parser.add_argument("--concurrency-count", type=int, default=10)
370
+ parser.add_argument("--model-list-mode", type=str, default="once",
371
+ choices=["once", "reload"])
372
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
373
+ parser.add_argument("--device", type=str, default="cuda")
374
+ parser.add_argument("--load-8bit", action="store_true")
375
+ parser.add_argument("--load-4bit", action="store_true")
376
+ parser.add_argument("--moderate", action="store_true")
377
+ parser.add_argument("--embed", action="store_true")
378
+ args = parser.parse_args()
379
+ logger.info(f"args: {args}")
380
+
381
+ model = ModelWorker(args.model_path, None, None, args.load_8bit, args.load_4bit, args.device)
382
+
383
+ logger.info(args)
384
+ demo = build_demo(args.embed)
385
+ demo.queue(
386
+ concurrency_count=args.concurrency_count,
387
+ api_open=False
388
+ ).launch(
389
+ server_name=args.host,
390
+ server_port=args.port,
391
+ share=False
392
+ )
mplug_docowl/local_serve/model_worker.py ADDED
@@ -0,0 +1,143 @@
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ import requests
12
+ import torch
13
+ from functools import partial
14
+
15
+ from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
16
+ from mplug_owl2.utils import (build_logger, server_error_msg,
17
+ pretty_print_semaphore)
18
+ from mplug_owl2.model.builder import load_pretrained_model
19
+ from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
20
+ from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
21
+ from transformers import TextIteratorStreamer
22
+ from threading import Thread
23
+
24
+ GB = 1 << 30
25
+
26
+ worker_id = str(uuid.uuid4())[:6]
27
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
28
+
29
+ class ModelWorker:
30
+ def __init__(self, model_path, model_base, model_name, load_8bit, load_4bit, device):
31
+ self.worker_id = worker_id
32
+ if model_path.endswith("/"):
33
+ model_path = model_path[:-1]
34
+ if model_name is None:
35
+ model_paths = model_path.split("/")
36
+ if model_paths[-1].startswith('checkpoint-'):
37
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
38
+ else:
39
+ self.model_name = model_paths[-1]
40
+ else:
41
+ self.model_name = model_name
42
+
43
+ self.device = device
44
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
45
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
46
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
47
+ self.is_multimodal = True
48
+
49
+ @torch.inference_mode()
50
+ def generate_stream(self, params):
51
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
52
+
53
+ prompt = params["prompt"]
54
+ ori_prompt = prompt
55
+ images = params.get("images", None)
56
+ num_image_tokens = 0
57
+ if images is not None and len(images) > 0 and self.is_multimodal:
58
+ if len(images) > 0:
59
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
60
+ raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
61
+
62
+ images = [load_image_from_base64(image) for image in images]
63
+ images = process_images(images, image_processor, model.config)
64
+
65
+ if type(images) is list:
66
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
67
+ else:
68
+ images = images.to(self.model.device, dtype=torch.float16)
69
+
70
+ replace_token = DEFAULT_IMAGE_TOKEN
71
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
72
+
73
+ num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
74
+ else:
75
+ images = None
76
+ image_args = {"images": images}
77
+ else:
78
+ images = None
79
+ image_args = {}
80
+
81
+ temperature = float(params.get("temperature", 1.0))
82
+ top_p = float(params.get("top_p", 1.0))
83
+ max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
84
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
85
+ stop_str = params.get("stop", None)
86
+ do_sample = True if temperature > 0.001 else False
87
+
88
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
89
+ keywords = [stop_str]
90
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
92
+
93
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
94
+
95
+ if max_new_tokens < 1:
96
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
97
+ return
98
+
99
+ thread = Thread(target=model.generate, kwargs=dict(
100
+ inputs=input_ids,
101
+ do_sample=do_sample,
102
+ temperature=temperature,
103
+ top_p=top_p,
104
+ max_new_tokens=max_new_tokens,
105
+ streamer=streamer,
106
+ stopping_criteria=[stopping_criteria],
107
+ use_cache=True,
108
+ **image_args
109
+ ))
110
+ thread.start()
111
+
112
+ generated_text = ori_prompt
113
+ for new_text in streamer:
114
+ generated_text += new_text
115
+ if generated_text.endswith(stop_str):
116
+ generated_text = generated_text[:-len(stop_str)]
117
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode()
118
+
119
+ def generate_stream_gate(self, params):
120
+ try:
121
+ for x in self.generate_stream(params):
122
+ yield x
123
+ except ValueError as e:
124
+ print("Caught ValueError:", e)
125
+ ret = {
126
+ "text": server_error_msg,
127
+ "error_code": 1,
128
+ }
129
+ yield json.dumps(ret).encode()
130
+ except torch.cuda.CudaError as e:
131
+ print("Caught torch.cuda.CudaError:", e)
132
+ ret = {
133
+ "text": server_error_msg,
134
+ "error_code": 1,
135
+ }
136
+ yield json.dumps(ret).encode()
137
+ except Exception as e:
138
+ print("Caught Unknown Error", e)
139
+ ret = {
140
+ "text": server_error_msg,
141
+ "error_code": 1,
142
+ }
143
+ yield json.dumps(ret).encode()
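A small consumption sketch (not in the commit) mirroring how local_web_server.py above iterates this generator; the checkpoint path and image file are placeholders, and the payload keys follow the params dict read in generate_stream:
import base64, json
from mplug_docowl.local_serve.model_worker import ModelWorker

worker = ModelWorker("path/to/docowl-checkpoint", None, None, False, False, "cuda")  # placeholder path
img_b64 = base64.b64encode(open("page.png", "rb").read()).decode()  # placeholder image
payload = {
    "prompt": "USER: <|image|>Summarize this page. ASSISTANT:",
    "images": [img_b64],  # one base64 image per <|image|> token in the prompt
    "temperature": 0.2, "top_p": 0.7, "max_new_tokens": 256, "stop": "</s>",
}
for chunk in worker.generate_stream_gate(payload):
    data = json.loads(chunk.decode().rstrip("\x00"))  # some chunks carry a trailing NUL separator
    if data["error_code"] == 0:
        print(data["text"], end="\r")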
mplug_docowl/mm_utils.py ADDED
@@ -0,0 +1,112 @@
+ from PIL import Image
+ from io import BytesIO
+ import base64
+
+ import torch
+ from transformers import StoppingCriteria
+ from mplug_docowl.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
+ from icecream import ic
+
+
+ def load_image_from_base64(image):
+     return Image.open(BytesIO(base64.b64decode(image)))
+
+
+ def expand2square(pil_img, background_color):
+     width, height = pil_img.size
+     if width == height:
+         return pil_img
+     elif width > height:
+         result = Image.new(pil_img.mode, (width, width), background_color)
+         result.paste(pil_img, (0, (width - height) // 2))
+         return result
+     else:
+         result = Image.new(pil_img.mode, (height, height), background_color)
+         result.paste(pil_img, ((height - width) // 2, 0))
+         return result
+
+
+ def process_images(images, image_processor, model_cfg=None):
+     if model_cfg is not None:
+         image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+     else:
+         image_aspect_ratio = 'resize'
+     new_images = []
+     if image_aspect_ratio == 'pad':
+         for image in images:
+             image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+             image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             new_images.append(image)
+     elif image_aspect_ratio == 'resize':
+         for image in images:
+             max_edge = max(image.size)
+             image = image.resize((max_edge, max_edge))
+             image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             new_images.append(image)
+     else:
+         return image_processor(images, return_tensors='pt')['pixel_values']
+     if all(x.shape == new_images[0].shape for x in new_images):
+         new_images = torch.stack(new_images, dim=0)
+     return new_images
+
+
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+     prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
+
+     def insert_separator(X, sep):
+         return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+     input_ids = []
+     offset = 0
+     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+         offset = 1
+         input_ids.append(prompt_chunks[0][0])
+
+     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+         input_ids.extend(x[offset:])
+
+     if return_tensors is not None:
+         if return_tensors == 'pt':
+             return torch.tensor(input_ids, dtype=torch.long)
+         raise ValueError(f'Unsupported tensor type: {return_tensors}')
+     return input_ids
+
+
+ def get_model_name_from_path(model_path):
+     model_path = model_path.strip("/")
+     model_paths = model_path.split("/")
+     if model_paths[-1].startswith('checkpoint-'):
+         return model_paths[-2] + "_" + model_paths[-1]
+     else:
+         return model_paths[-1]
+
+
+ class KeywordsStoppingCriteria(StoppingCriteria):
+     def __init__(self, keywords, tokenizer, input_ids):
+         self.keywords = keywords
+         self.keyword_ids = []
+         self.max_keyword_len = 0
+         for keyword in keywords:
+             cur_keyword_ids = tokenizer(keyword).input_ids
+             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                 cur_keyword_ids = cur_keyword_ids[1:]
+             if len(cur_keyword_ids) > self.max_keyword_len:
+                 self.max_keyword_len = len(cur_keyword_ids)
+             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+         self.tokenizer = tokenizer
+         self.start_len = input_ids.shape[1]
+
+     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
+         offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+         for keyword_id in self.keyword_ids:
+             if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
+                 return True
+         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+         for keyword in self.keywords:
+             if keyword in outputs:
+                 return True
+         return False
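A brief sketch (not part of the commit) of how tokenizer_image_token is typically called, as in the serving code above; the base tokenizer checkpoint name is an assumption:
from transformers import AutoTokenizer
from mplug_docowl.constants import IMAGE_TOKEN_INDEX
from mplug_docowl.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)  # assumed base tokenizer
# Text chunks are tokenized separately and each <|image|> is spliced in as IMAGE_TOKEN_INDEX (-200).
input_ids = tokenizer_image_token("USER: <|image|>What is this? ASSISTANT:",
                                  tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = input_ids.unsqueeze(0)  # add the batch dimension expected by model.generate()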
mplug_docowl/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .modeling_mplug_docowl import MPLUGDocOwlLlamaForCausalLM
+ from .configuration_mplug_docowl import MPLUGDocOwlConfig
mplug_docowl/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (315 Bytes)
mplug_docowl/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.64 kB)
mplug_docowl/model/__pycache__/configuration_mplug_docowl.cpython-310.pyc ADDED
Binary file (12.9 kB)
mplug_docowl/model/__pycache__/configuration_mplug_docowl2.cpython-310.pyc ADDED
Binary file (14 kB)
mplug_docowl/model/__pycache__/convert_mplug_docowl2_weight_to_hf.cpython-310.pyc ADDED
Binary file (9.28 kB)
mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf.cpython-310.pyc ADDED
Binary file (9.12 kB)
mplug_docowl/model/__pycache__/convert_mplug_docowl_weight_to_hf_v2.cpython-310.pyc ADDED
Binary file (9.07 kB)
mplug_docowl/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc ADDED
Binary file (7.6 kB)
mplug_docowl/model/__pycache__/modeling_llama2.cpython-310.pyc ADDED
Binary file (13.3 kB)
mplug_docowl/model/__pycache__/modeling_mplug_docowl.cpython-310.pyc ADDED
Binary file (9.28 kB)
mplug_docowl/model/__pycache__/modeling_mplug_docowl2.cpython-310.pyc ADDED
Binary file (10 kB)
mplug_docowl/model/__pycache__/visual_encoder.cpython-310.pyc ADDED
Binary file (15.1 kB)
mplug_docowl/model/builder.py ADDED
@@ -0,0 +1,81 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import os
+ import warnings
+ import shutil
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+ from transformers.models.clip.image_processing_clip import CLIPImageProcessor
+ import torch
+ from mplug_docowl.model import *
+ from icecream import ic
+
+
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
+     kwargs = {"device_map": device_map}
+
+     if device != "cuda":
+         kwargs['device_map'] = {"": device}
+
+     if load_8bit:
+         kwargs['load_in_8bit'] = True
+     elif load_4bit:
+         kwargs['load_in_4bit'] = True
+         kwargs['quantization_config'] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type='nf4'
+         )
+     else:
+         kwargs['torch_dtype'] = torch.float16
+
+     # check both substrings explicitly; the original `'paperowl' or 'docowl' in ...` was always truthy
+     if 'paperowl' in model_name.lower() or 'docowl' in model_name.lower():
+         if model_base is not None:
+             # this may be mm projector only
+             print('Loading mPLUG-DocOwl from base model...')
+             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+             cfg_pretrained = AutoConfig.from_pretrained(model_path)
+             model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+         else:
+             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+             model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+     else:
+         # Load language model
+         if model_base is not None:
+             # PEFT model
+             from peft import PeftModel
+             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+             model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+             print(f"Loading LoRA weights from {model_path}")
+             model = PeftModel.from_pretrained(model, model_path)
+             print("Merging weights")
+             model = model.merge_and_unload()
+             print('Convert to FP16...')
+             model.to(torch.float16)
+         else:
+             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+             model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+     # vision_tower = model.get_model().vision_model
+     # vision_tower.to(device=device, dtype=torch.float16)
+     image_processor = CLIPImageProcessor.from_pretrained(model_path)
+
+     if hasattr(model.config, "max_sequence_length"):
+         context_len = model.config.max_sequence_length
+     else:
+         context_len = 2048
+
+     return tokenizer, model, image_processor, context_len
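A usage sketch (not in this commit) combining this loader with get_model_name_from_path from mm_utils.py; the checkpoint path is a placeholder:
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import get_model_name_from_path

model_path = "./checkpoints/docowl-7b"  # placeholder local path
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), load_8bit=False, load_4bit=False)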
mplug_docowl/model/configuration_mplug_docowl.py ADDED
@@ -0,0 +1,318 @@
1
+ # Copyright (c) Alibaba.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import copy
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
11
+ from transformers.utils import logging
12
+ from transformers.models.auto import CONFIG_MAPPING
13
+
14
+
15
+ class LlamaConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
18
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
19
+ defaults will yield a similar configuration to that of the LLaMA-7B.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+
25
+ Args:
26
+ vocab_size (`int`, *optional*, defaults to 32000):
27
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
28
+ `inputs_ids` passed when calling [`LlamaModel`]
29
+ hidden_size (`int`, *optional*, defaults to 4096):
30
+ Dimension of the hidden representations.
31
+ intermediate_size (`int`, *optional*, defaults to 11008):
32
+ Dimension of the MLP representations.
33
+ num_hidden_layers (`int`, *optional*, defaults to 32):
34
+ Number of hidden layers in the Transformer decoder.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer decoder.
37
+ num_key_value_heads (`int`, *optional*):
38
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
41
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
+ by meanpooling all the original heads within that group. For more details checkout [this
43
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
44
+ `num_attention_heads`.
45
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
46
+ The non-linear activation function (function or string) in the decoder.
47
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
48
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
49
+ Llama 2 up to 4096, CodeLlama up to 16384.
50
+ initializer_range (`float`, *optional*, defaults to 0.02):
51
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
52
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
53
+ The epsilon used by the rms normalization layers.
54
+ use_cache (`bool`, *optional*, defaults to `True`):
55
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
56
+ relevant if `config.is_decoder=True`.
57
+ pad_token_id (`int`, *optional*):
58
+ Padding token id.
59
+ bos_token_id (`int`, *optional*, defaults to 1):
60
+ Beginning of stream token id.
61
+ eos_token_id (`int`, *optional*, defaults to 2):
62
+ End of stream token id.
63
+ pretraining_tp (`int`, *optional*, defaults to 1):
64
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
65
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
66
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
67
+ issue](https://github.com/pytorch/pytorch/issues/76232).
68
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
69
+ Whether to tie weight embeddings
70
+ rope_theta (`float`, *optional*, defaults to 10000.0):
71
+ The base period of the RoPE embeddings.
72
+ rope_scaling (`Dict`, *optional*):
73
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
74
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
75
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
76
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
77
+ these scaling strategies behave:
78
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
79
+ experimental feature, subject to breaking API changes in future versions.
80
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
81
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
82
+
83
+
84
+ ```python
85
+ >>> from transformers import LlamaModel, LlamaConfig
86
+
87
+ >>> # Initializing a LLaMA llama-7b style configuration
88
+ >>> configuration = LlamaConfig()
89
+
90
+ >>> # Initializing a model from the llama-7b style configuration
91
+ >>> model = LlamaModel(configuration)
92
+
93
+ >>> # Accessing the model configuration
94
+ >>> configuration = model.config
95
+ ```"""
96
+ model_type = "llama"
97
+ keys_to_ignore_at_inference = ["past_key_values"]
98
+
99
+ def __init__(
100
+ self,
101
+ vocab_size=32000,
102
+ hidden_size=4096,
103
+ intermediate_size=11008,
104
+ num_hidden_layers=32,
105
+ num_attention_heads=32,
106
+ num_key_value_heads=None,
107
+ hidden_act="silu",
108
+ max_position_embeddings=2048,
109
+ initializer_range=0.02,
110
+ rms_norm_eps=1e-6,
111
+ use_cache=True,
112
+ pad_token_id=None,
113
+ bos_token_id=1,
114
+ eos_token_id=2,
115
+ pretraining_tp=1,
116
+ tie_word_embeddings=False,
117
+ rope_theta=10000.0,
118
+ rope_scaling=None,
119
+ attention_bias=False,
120
+ **kwargs,
121
+ ):
122
+ self.vocab_size = vocab_size
123
+ self.max_position_embeddings = max_position_embeddings
124
+ self.hidden_size = hidden_size
125
+ self.intermediate_size = intermediate_size
126
+ self.num_hidden_layers = num_hidden_layers
127
+ self.num_attention_heads = num_attention_heads
128
+
129
+ # for backward compatibility
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+
133
+ self.num_key_value_heads = num_key_value_heads
134
+ self.hidden_act = hidden_act
135
+ self.initializer_range = initializer_range
136
+ self.rms_norm_eps = rms_norm_eps
137
+ self.pretraining_tp = pretraining_tp
138
+ self.use_cache = use_cache
139
+ self.rope_theta = rope_theta
140
+ self.rope_scaling = rope_scaling
141
+ self._rope_scaling_validation()
142
+ self.attention_bias = attention_bias
143
+
144
+ super().__init__(
145
+ pad_token_id=pad_token_id,
146
+ bos_token_id=bos_token_id,
147
+ eos_token_id=eos_token_id,
148
+ tie_word_embeddings=tie_word_embeddings,
149
+ **kwargs,
150
+ )
151
+
152
+ def _rope_scaling_validation(self):
153
+ """
154
+ Validate the `rope_scaling` configuration.
155
+ """
156
+ if self.rope_scaling is None:
157
+ return
158
+
159
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
160
+ raise ValueError(
161
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
162
+ f"got {self.rope_scaling}"
163
+ )
164
+ rope_scaling_type = self.rope_scaling.get("type", None)
165
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
166
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
167
+ raise ValueError(
168
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
169
+ )
170
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
171
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
172
+
173
+
174
+ class MplugOwlVisionConfig(PretrainedConfig):
175
+ r"""
176
+ This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
177
+ a
178
+ mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
179
+ configuration with the defaults will yield a similar configuration to that of the mPLUG-Owl
180
+ [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
181
+
182
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
183
+ documentation from [`PretrainedConfig`] for more information.
184
+
185
+ Args:
186
+ hidden_size (`int`, *optional*, defaults to 768):
187
+ Dimensionality of the encoder layers and the pooler layer.
188
+ intermediate_size (`int`, *optional*, defaults to 3072):
189
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
190
+ num_hidden_layers (`int`, *optional*, defaults to 24):
191
+ Number of hidden layers in the Transformer encoder.
192
+ num_attention_heads (`int`, *optional*, defaults to 16):
193
+ Number of attention heads for each attention layer in the Transformer encoder.
194
+ image_size (`int`, *optional*, defaults to 448):
195
+ The size (resolution) of each image.
196
+ patch_size (`int`, *optional*, defaults to 14):
197
+ The size (resolution) of each patch.
198
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
199
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
200
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
201
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
202
+ The epsilon used by the layer normalization layers.
203
+ attention_dropout (`float`, *optional*, defaults to 0.0):
204
+ The dropout ratio for the attention probabilities.
205
+ initializer_range (`float`, *optional*, defaults to 0.02):
206
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
207
+ initializer_factor (`float`, *optional*, defaults to 1):
208
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
209
+ testing).
210
+
211
+
212
+ """
213
+
214
+ model_type = "mplug_owl_vision_model"
215
+
216
+ def __init__(
217
+ self,
218
+ hidden_size=1024,
219
+ intermediate_size=4096,
220
+ projection_dim=768,
221
+ num_hidden_layers=24,
222
+ num_attention_heads=16,
223
+ num_channels=3,
224
+ image_size=448,
225
+ patch_size=14,
226
+ hidden_act="quick_gelu",
227
+ layer_norm_eps=1e-6,
228
+ attention_dropout=0.0,
229
+ initializer_range=0.02,
230
+ initializer_factor=1.0,
231
+ use_flash_attn=False,
232
+ **kwargs,
233
+ ):
234
+ super().__init__(**kwargs)
235
+ self.hidden_size = hidden_size
236
+ self.intermediate_size = intermediate_size
237
+ self.projection_dim = projection_dim
238
+ self.num_hidden_layers = num_hidden_layers
239
+ self.num_attention_heads = num_attention_heads
240
+ self.num_channels = num_channels
241
+ self.patch_size = patch_size
242
+ self.image_size = image_size
243
+ self.initializer_range = initializer_range
244
+ self.initializer_factor = initializer_factor
245
+ self.attention_dropout = attention_dropout
246
+ self.layer_norm_eps = layer_norm_eps
247
+ self.hidden_act = hidden_act
248
+ self.use_flash_attn = use_flash_attn
249
+
250
+ @classmethod
251
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
252
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
253
+
254
+ # get the vision config dict if we are loading from MplugOwlConfig
255
+ if config_dict.get("model_type") == "mplug-owl":
256
+ config_dict = config_dict["vision_config"]
257
+
258
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
259
+ logger.warning(
260
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
261
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
262
+ )
263
+
264
+ return cls.from_dict(config_dict, **kwargs)
265
+
266
+
267
+ class MplugDocOwlHReducerConfig(PretrainedConfig):
268
+ model_type = "mplug_docowl_hreducer"
269
+
270
+ def __init__(
271
+ self,
272
+ hidden_size=1024,
273
+ initializer_range=0.02,
274
+ layer_norm_eps=1e-6,
275
+ conv_shape='1x4',
276
+ **kwargs,
277
+ ):
278
+ super().__init__(**kwargs)
279
+ self.hidden_size = hidden_size
280
+ self.initializer_range = initializer_range
281
+ self.layer_norm_eps = layer_norm_eps
282
+ self.conv_shape = conv_shape
283
+
284
+ @classmethod
285
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
286
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
287
+
288
+ # get the visual_abstractor config dict if we are loading from MplugOwlConfig
289
+ if config_dict.get("model_type") == "mplug-docowl":
290
+ config_dict = config_dict["hreducer_config"]
291
+
292
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
293
+ logger.warning(
294
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
295
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
296
+ )
297
+
298
+ return cls.from_dict(config_dict, **kwargs)
299
+
300
+ DEFAULT_VISUAL_CONFIG = {
301
+ "visual_model": MplugOwlVisionConfig().to_dict(),
302
+ "visual_hreducer": MplugDocOwlHReducerConfig().to_dict()
303
+ }
304
+
305
+ class MPLUGDocOwlConfig(LlamaConfig):
306
+ model_type = "mplug_docowl"
307
+ def __init__(self, visual_config=None, **kwargs):
308
+ if visual_config is None:
309
+ self.visual_config = DEFAULT_VISUAL_CONFIG
310
+ else:
311
+ self.visual_config = visual_config
312
+
313
+ super().__init__(
314
+ **kwargs,
315
+ )
316
+
317
+ if __name__ == "__main__":
318
+ print(MplugOwlVisionConfig().to_dict())
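As a usage sketch for the classes defined above (the import path mirrors this file's location in the repo; the `conv_shape` override is only an illustrative value), the composed config can be built from the defaults or with a custom `visual_config`:

```python
# Sketch: composing MPLUGDocOwlConfig; conv_shape="1x8" is a hypothetical override.
from mplug_docowl.model.configuration_mplug_docowl import (
    MPLUGDocOwlConfig,
    MplugOwlVisionConfig,
    MplugDocOwlHReducerConfig,
)

# Default: visual_config falls back to DEFAULT_VISUAL_CONFIG.
default_config = MPLUGDocOwlConfig()

# Custom: override the HReducer conv shape while keeping the default vision tower.
custom_visual_config = {
    "visual_model": MplugOwlVisionConfig().to_dict(),
    "visual_hreducer": MplugDocOwlHReducerConfig(conv_shape="1x8").to_dict(),
}
custom_config = MPLUGDocOwlConfig(visual_config=custom_visual_config)
print(custom_config.visual_config["visual_hreducer"]["conv_shape"])  # "1x8"
```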
mplug_docowl/model/convert_mplug_docowl_weight_to_hf.py ADDED
@@ -0,0 +1,319 @@
1
+ # Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import math
18
+ import os
19
+ import shutil
20
+ import warnings
21
+
22
+ import torch
23
+
24
+ from transformers import LlamaTokenizer
25
+ from .configuration_mplug_docowl import MPLUGDocOwlConfig
26
+ from icecream import ic
27
+
28
+ try:
29
+ from transformers import LlamaTokenizerFast
30
+ except ImportError as e:
31
+ warnings.warn(e)
32
+ warnings.warn(
33
+ "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
34
+ )
35
+ LlamaTokenizerFast = None
36
+
37
+ """
38
+ Sample usage:
39
+
40
+ ```
41
+ python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
42
+ --input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
43
+ ```
44
+
45
+ Thereafter, models can be loaded via:
46
+
47
+ ```py
48
+ from transformers import LlamaForCausalLM, LlamaTokenizer
49
+
50
+ model = LlamaForCausalLM.from_pretrained("/output/path")
51
+ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
52
+ ```
53
+
54
+ Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
55
+ come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them must be loaded in RAM).
56
+ """
57
+
58
+ llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
59
+ llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
60
+ llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
61
+ 70: 28672} # should be (2/3)*4*d, but it isn't exactly that
62
+ llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
63
+
64
+
65
+ def compute_intermediate_size(n):
66
+ return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
67
+
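As a quick sanity check of the rounding above (using the 7B and 13B hidden sizes as example inputs), the helper reproduces the values listed in `llama_s2dense`:

```python
import math

def compute_intermediate_size(n):
    # Round ceil(8n/3) up to the next multiple of 256.
    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256

assert compute_intermediate_size(4096) == 11008   # matches llama_s2dense[7]
assert compute_intermediate_size(5120) == 13824   # matches llama_s2dense[13]
```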
68
+
69
+ def read_json(path):
70
+ with open(path, "r") as f:
71
+ return json.load(f)
72
+
73
+
74
+ def write_json(text, path):
75
+ with open(path, "w") as f:
76
+ json.dump(text, f)
77
+
78
+
79
+ def write_model(model_path,
80
+ input_base_path,
81
+ model_size,
82
+ num_input_shards=1,
83
+ num_output_shards=2,
84
+ skip_permute=True,
85
+ norm_eps=1e-05):
86
+ # if os.path.exists(model_path):
87
+ # shutil.rmtree(model_path)
88
+ os.makedirs(model_path, exist_ok=True)
89
+ # tmp_model_path = os.path.join(model_path, "tmp")
90
+ tmp_model_path = model_path
91
+ os.makedirs(tmp_model_path, exist_ok=True)
92
+
93
+ num_shards = num_input_shards
94
+ n_layers = llama_s2layer[model_size]
95
+ n_heads = llama_s2heads[model_size]
96
+ n_heads_per_shard = n_heads // num_shards
97
+ n_dense = llama_s2dense[model_size]
98
+ n_hidden = llama_s2hidden[model_size]
99
+ hidden_per_head = n_hidden // n_heads
100
+ base = 10000.0
101
+ inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
102
+
103
+ # permute for sliced rotary
104
+ def permute(w, skip_permute=skip_permute):
105
+ if skip_permute:
106
+ return w
107
+ return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
108
+
109
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
110
+ # Load weights
111
+ if num_shards==1:
112
+ # Not sharded
113
+ # (The sharded implementation would also work, but this is simpler.)
114
+ # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
115
+ if os.path.exists(os.path.join(input_base_path, 'release')):
116
+ filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
117
+ elif input_base_path.split('/')[-1].startswith('iter_'):
118
+ iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
119
+ load_dir = '/'.join(input_base_path.split('/')[:-1])
120
+ filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
121
+ if not os.path.exists(filename):
122
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
123
+ else:
124
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
125
+ with open(tracker_filename, 'r') as f:
126
+ metastring = f.read().strip()
127
+ iteration = 'iter_{:07d}'.format(int(metastring))
128
+ filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
129
+ if not os.path.exists(filename):
130
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
131
+ original_filename = filename
132
+ loaded = torch.load(filename, map_location="cpu")['model']['language_model']
133
+
134
+ else:
135
+ # Sharded
136
+ filenames = []
137
+ for i in range(num_shards):
138
+ if os.path.exists(os.path.join(input_base_path, 'release')):
139
+ filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
140
+ else:
141
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
142
+ with open(tracker_filename, 'r') as f:
143
+ metastring = f.read().strip()
144
+ iteration = 'iter_{:07d}'.format(int(metastring))
145
+ filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
146
+ if not os.path.exists(filename):
147
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
148
+ filenames.append(filename)
149
+ loaded = [
150
+ torch.load(filenames[i], map_location="cpu")['model']['language_model']
151
+ for i in range(num_shards)
152
+ ]
153
+
154
+ print('Llama-Megatron Loaded!')
155
+ param_count = 0
156
+ index_dict = {"weight_map": {}}
157
+
158
+ print(f'Weighted Converting for {n_layers} layers...')
159
+ for layer_i in range(n_layers):
160
+ print(layer_i)
161
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
162
+ if num_shards == 1:
163
+ # Unsharded
164
+ state_dict = {
165
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
166
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
167
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
168
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
169
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
170
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
171
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
172
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
173
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
174
+ f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
175
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
176
+ f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
177
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
178
+ }
179
+ else:
180
+ raise NotImplementedError
181
+
182
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
183
+ for k, v in state_dict.items():
184
+ index_dict["weight_map"][k] = filename
185
+ param_count += v.numel()
186
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
187
+ print(f'Sharded file saved to {filename}')
188
+
189
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
190
+ if num_shards==1:
191
+ # Unsharded
192
+ state_dict = {
193
+ "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
194
+ "model.norm.weight": loaded['encoder']['norm.weight'],
195
+ "lm_head.weight": loaded['encoder']['lm_head.weight'],
196
+ }
197
+ else:
198
+ state_dict = {
199
+ "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
200
+ "model.norm.weight": loaded[0]['encoder']['norm.weight'],
201
+ "lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
202
+ }
203
+
204
+
205
+ loaded_all = torch.load(original_filename, map_location="cpu")['model']
206
+ # Vision Part
207
+ state_dict.update({
208
+ "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
209
+ "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
210
+ "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
211
+ "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
212
+ "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
213
+ "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
214
+ "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
215
+ })
216
+ for v_layer_idx in range(24):
217
+ state_dict.update({
218
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
219
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
220
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
221
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
222
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
223
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
224
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
225
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
226
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
227
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
228
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
229
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
230
+ })
231
+
232
+ # Vision2Text Part: HReducer
233
+ state_dict.update({
234
+ "model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
235
+ "model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
236
+ "model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
237
+ "model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
238
+ "model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
239
+ })
240
+ # reducer_before: conv (layer 0) + GELU (layer 1)
241
+ state_dict.update({
242
+ f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
243
+ f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
244
+ })
245
+ # reducer conv
246
+ state_dict.update({
247
+ f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
248
+ f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
249
+ })
250
+
251
+ for k, v in state_dict.items():
252
+ # ic(k, v)
253
+ index_dict["weight_map"][k] = filename
254
+ param_count += v.numel()
255
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
256
+
257
+ # Write configs
258
+ index_dict["metadata"] = {"total_size": param_count * 2}
259
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
260
+
261
+ config = MPLUGDocOwlConfig()
262
+ config.save_pretrained(tmp_model_path)
263
+
264
+ # Make space so we can load the model properly now.
265
+ del state_dict
266
+ del loaded
267
+ del loaded_all
268
+ gc.collect()
269
+
270
+ def write_tokenizer(tokenizer_path, input_tokenizer_path):
271
+ # Initialize the tokenizer based on the `spm` model
272
+ tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
273
+ print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
274
+ tokenizer = tokenizer_class(input_tokenizer_path)
275
+ tokenizer.save_pretrained(tokenizer_path)
276
+
277
+
278
+ def main():
279
+ parser = argparse.ArgumentParser()
280
+ parser.add_argument(
281
+ "--input_dir",
282
+ help="Location of LLaMA_Megatron weights",
283
+ )
284
+ parser.add_argument(
285
+ "--model_size",
286
+ type=int,
287
+ default=7,
288
+ choices=[7, 13, 30, 65, 70],
289
+ )
290
+ parser.add_argument(
291
+ "--num_input_shards",
292
+ type=int,
293
+ default=1,
294
+ )
295
+ parser.add_argument(
296
+ "--num_output_shards",
297
+ type=int,
298
+ default=1,
299
+ )
300
+ parser.add_argument('--skip_permute', action='store_true')
301
+
302
+ parser.add_argument(
303
+ "--output_dir",
304
+ help="Location to write HF model and tokenizer",
305
+ )
306
+
307
+ args = parser.parse_args()
308
+ write_model(
309
+ model_path=args.output_dir,
310
+ input_base_path=args.input_dir,
311
+ model_size=args.model_size,
312
+ num_input_shards=args.num_input_shards,
313
+ num_output_shards=args.num_output_shards,
314
+ skip_permute=args.skip_permute
315
+ )
316
+
317
+
318
+ if __name__ == "__main__":
319
+ main()
mplug_docowl/model/convert_mplug_docowl_weight_to_hf_v2.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import math
18
+ import os
19
+ import shutil
20
+ import warnings
21
+
22
+ import torch
23
+
24
+ from transformers import LlamaTokenizer
25
+ from .configuration_mplug_docowl import MPLUGDocOwlConfig
26
+ from icecream import ic
27
+
28
+ try:
29
+ from transformers import LlamaTokenizerFast
30
+ except ImportError as e:
31
+ warnings.warn(e)
32
+ warnings.warn(
33
+ "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
34
+ )
35
+ LlamaTokenizerFast = None
36
+
37
+ """
38
+ Sample usage:
39
+
40
+ ```
41
+ python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
42
+ --input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
43
+ ```
44
+
45
+ Thereafter, models can be loaded via:
46
+
47
+ ```py
48
+ from transformers import LlamaForCausalLM, LlamaTokenizer
49
+
50
+ model = LlamaForCausalLM.from_pretrained("/output/path")
51
+ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
52
+ ```
53
+
54
+ Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
55
+ come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them must be loaded in RAM).
56
+ """
57
+
58
+ llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
59
+ llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
60
+ llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
61
+ 70: 28672} # should be (2/3)*4*d, but it isn't exactly that
62
+ llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
63
+
64
+
65
+ def compute_intermediate_size(n):
66
+ return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
67
+
68
+
69
+ def read_json(path):
70
+ with open(path, "r") as f:
71
+ return json.load(f)
72
+
73
+
74
+ def write_json(text, path):
75
+ with open(path, "w") as f:
76
+ json.dump(text, f)
77
+
78
+
79
+ def write_model(model_path,
80
+ input_base_path,
81
+ model_size,
82
+ num_input_shards=1,
83
+ num_output_shards=2,
84
+ skip_permute=True,
85
+ norm_eps=1e-05):
86
+ # if os.path.exists(model_path):
87
+ # shutil.rmtree(model_path)
88
+ os.makedirs(model_path, exist_ok=True)
89
+ # tmp_model_path = os.path.join(model_path, "tmp")
90
+ tmp_model_path = model_path
91
+ os.makedirs(tmp_model_path, exist_ok=True)
92
+
93
+ num_shards = num_input_shards
94
+ n_layers = llama_s2layer[model_size]
95
+ n_heads = llama_s2heads[model_size]
96
+ n_heads_per_shard = n_heads // num_shards
97
+ n_dense = llama_s2dense[model_size]
98
+ n_hidden = llama_s2hidden[model_size]
99
+ hidden_per_head = n_hidden // n_heads
100
+ base = 10000.0
101
+ inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
102
+
103
+ # permute for sliced rotary
104
+ def permute(w, skip_permute=skip_permute):
105
+ if skip_permute:
106
+ return w
107
+ return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
108
+
109
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
110
+ # Load weights
111
+ if num_shards==1:
112
+ # Not sharded
113
+ # (The sharded implementation would also work, but this is simpler.)
114
+ # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
115
+ if os.path.exists(os.path.join(input_base_path, 'release')):
116
+ filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
117
+ elif input_base_path.split('/')[-1].startswith('iter_'):
118
+ iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
119
+ load_dir = '/'.join(input_base_path.split('/')[:-1])
120
+ filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
121
+ if not os.path.exists(filename):
122
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
123
+ else:
124
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
125
+ with open(tracker_filename, 'r') as f:
126
+ metastring = f.read().strip()
127
+ iteration = 'iter_{:07d}'.format(int(metastring))
128
+ filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
129
+ if not os.path.exists(filename):
130
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
131
+ original_filename = filename
132
+ loaded = torch.load(filename, map_location="cpu")['model']['language_model']
133
+
134
+ else:
135
+ # Sharded
136
+ filenames = []
137
+ for i in range(num_shards):
138
+ if os.path.exists(os.path.join(input_base_path, 'release')):
139
+ filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
140
+ else:
141
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
142
+ with open(tracker_filename, 'r') as f:
143
+ metastring = f.read().strip()
144
+ iteration = 'iter_{:07d}'.format(int(metastring))
145
+ filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
146
+ if not os.path.exists(filename):
147
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
148
+ filenames.append(filename)
149
+ loaded = [
150
+ torch.load(filenames[i], map_location="cpu")['model']['language_model']
151
+ for i in range(num_shards)
152
+ ]
153
+
154
+ print('Llama-Megatron Loaded!')
155
+ param_count = 0
156
+ index_dict = {"weight_map": {}}
157
+ state_dict = {}
158
+ print(f'Weighted Converting for {n_layers} layers...')
159
+ for layer_i in range(n_layers):
160
+ print(layer_i)
161
+ # filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
162
+ if num_shards == 1:
163
+ # Unsharded
164
+ state_dict.update({
165
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
166
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
167
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
168
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
169
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
170
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
171
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
172
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
173
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
174
+ f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
175
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
176
+ f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
177
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
178
+ })
179
+ else:
180
+ raise NotImplementedError
181
+
182
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
183
+ for k, v in state_dict.items():
184
+ index_dict["weight_map"][k] = filename
185
+ param_count += v.numel()
186
+ # torch.save(state_dict, os.path.join(tmp_model_path, filename))
187
+ # print(f'Sharded file saved to {filename}')
188
+
189
+ # filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
190
+ filename = "pytorch_model.bin"
191
+ if num_shards==1:
192
+ # Unsharded
193
+ state_dict.update({
194
+ "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
195
+ "model.norm.weight": loaded['encoder']['norm.weight'],
196
+ "lm_head.weight": loaded['encoder']['lm_head.weight'],
197
+ })
198
+ else:
199
+ state_dict.update({
200
+ "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
201
+ "model.norm.weight": loaded[0]['encoder']['norm.weight'],
202
+ "lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
203
+ })
204
+
205
+ loaded_all = torch.load(original_filename, map_location="cpu")['model']
206
+ # Vision Part
207
+ state_dict.update({
208
+ "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
209
+ "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
210
+ "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
211
+ "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
212
+ "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
213
+ "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
214
+ "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
215
+ })
216
+ for v_layer_idx in range(24):
217
+ state_dict.update({
218
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
219
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
220
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
221
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
222
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
223
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
224
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
225
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
226
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
227
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
228
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
229
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
230
+ })
231
+
232
+ # Vision2Text Part: HReducer
233
+ state_dict.update({
234
+ "model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
235
+ "model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
236
+ "model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
237
+ "model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
238
+ "model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
239
+ })
240
+ # reducer_before: conv (layer 0) + GELU (layer 1)
241
+ state_dict.update({
242
+ f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
243
+ f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
244
+ })
245
+ # reducer conv
246
+ state_dict.update({
247
+ f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
248
+ f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
249
+ })
250
+
251
+ for k, v in state_dict.items():
252
+ # ic(k, v)
253
+ index_dict["weight_map"][k] = filename
254
+ param_count += v.numel()
255
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
256
+ print(f'save to {os.path.join(tmp_model_path, filename)}')
257
+
258
+ # Write configs
259
+ index_dict["metadata"] = {"total_size": param_count * 2}
260
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
261
+
262
+ config = MPLUGDocOwlConfig()
263
+ config.save_pretrained(tmp_model_path)
264
+
265
+ # Make space so we can load the model properly now.
266
+ del state_dict
267
+ del loaded
268
+ del loaded_all
269
+ gc.collect()
270
+
271
+ def write_tokenizer(tokenizer_path, input_tokenizer_path):
272
+ # Initialize the tokenizer based on the `spm` model
273
+ tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
274
+ print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
275
+ tokenizer = tokenizer_class(input_tokenizer_path)
276
+ tokenizer.save_pretrained(tokenizer_path)
277
+
278
+
279
+ def main():
280
+ parser = argparse.ArgumentParser()
281
+ parser.add_argument(
282
+ "--input_dir",
283
+ help="Location of LLaMA_Megatron weights",
284
+ )
285
+ parser.add_argument(
286
+ "--model_size",
287
+ type=int,
288
+ default=7,
289
+ choices=[7, 13, 30, 65, 70],
290
+ )
291
+ parser.add_argument(
292
+ "--num_input_shards",
293
+ type=int,
294
+ default=1,
295
+ )
296
+ parser.add_argument(
297
+ "--num_output_shards",
298
+ type=int,
299
+ default=1,
300
+ )
301
+ parser.add_argument('--skip_permute', action='store_true')
302
+
303
+ parser.add_argument(
304
+ "--output_dir",
305
+ help="Location to write HF model and tokenizer",
306
+ )
307
+
308
+ args = parser.parse_args()
309
+ write_model(
310
+ model_path=args.output_dir,
311
+ input_base_path=args.input_dir,
312
+ model_size=args.model_size,
313
+ num_input_shards=args.num_input_shards,
314
+ num_output_shards=args.num_output_shards,
315
+ skip_permute=args.skip_permute
316
+ )
317
+
318
+
319
+ if __name__ == "__main__":
320
+ main()
mplug_docowl/model/modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,247 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Optional, Tuple, Union
15
+
16
+ import torch
17
+
18
+
19
+ class AttentionMaskConverter:
20
+ """
21
+ A utility attention mask class that allows one to:
22
+ - Create a causal 4d mask
23
+ - Create a causal 4d mask with a sliding window
24
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
25
+ key_value_length) that can be multiplied with attention scores
26
+
27
+ Parameters:
28
+ is_causal (`bool`):
29
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
30
+
31
+ sliding_window (`int`, *optional*):
32
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
33
+ """
34
+
35
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
36
+ self.is_causal = is_causal
37
+ self.sliding_window = sliding_window
38
+
39
+ if self.sliding_window is not None and self.sliding_window <= 0:
40
+ raise ValueError(
41
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
42
+ )
43
+
44
+ def to_causal_4d(
45
+ self,
46
+ batch_size: int,
47
+ query_length: int,
48
+ key_value_length: int,
49
+ dtype: torch.dtype = torch.float32,
50
+ device: Union[torch.device, "str"] = "cpu",
51
+ ) -> torch.Tensor:
52
+ """
53
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
54
+ bias to upper right hand triangular matrix (causal mask).
55
+ """
56
+ if not self.is_causal:
57
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
58
+
59
+ # If shape is not cached, create a new causal mask and cache it
60
+ input_shape = (batch_size, query_length)
61
+ past_key_values_length = key_value_length - query_length
62
+
63
+ # create causal mask
64
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
65
+ causal_4d_mask = None
66
+ if input_shape[-1] > 1 or self.sliding_window is not None:
67
+ causal_4d_mask = self._make_causal_mask(
68
+ input_shape,
69
+ dtype,
70
+ device=device,
71
+ past_key_values_length=past_key_values_length,
72
+ sliding_window=self.sliding_window,
73
+ )
74
+
75
+ return causal_4d_mask
76
+
77
+ def to_4d(
78
+ self,
79
+ attention_mask_2d: torch.Tensor,
80
+ query_length: int,
81
+ key_value_length: Optional[int] = None,
82
+ dtype: torch.dtype = torch.float32,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
86
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
87
+ causal, a causal mask will be added.
88
+ """
89
+ input_shape = (attention_mask_2d.shape[0], query_length)
90
+
91
+ # create causal mask
92
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
93
+ causal_4d_mask = None
94
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
95
+ if key_value_length is None:
96
+ raise ValueError(
97
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
98
+ )
99
+
100
+ past_key_values_length = key_value_length - query_length
101
+ causal_4d_mask = self._make_causal_mask(
102
+ input_shape,
103
+ dtype,
104
+ device=attention_mask_2d.device,
105
+ past_key_values_length=past_key_values_length,
106
+ sliding_window=self.sliding_window,
107
+ )
108
+ elif self.sliding_window is not None:
109
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
110
+
111
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
112
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
113
+ attention_mask_2d.device
114
+ )
115
+ expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
116
+
117
+ return expanded_4d_mask
118
+
119
+ @staticmethod
120
+ def _make_causal_mask(
121
+ input_ids_shape: torch.Size,
122
+ dtype: torch.dtype,
123
+ device: torch.device,
124
+ past_key_values_length: int = 0,
125
+ sliding_window: Optional[int] = None,
126
+ ):
127
+ """
128
+ Make the causal mask used for uni-directional (causal) self-attention.
129
+ """
130
+ bsz, tgt_len = input_ids_shape
131
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
132
+ mask_cond = torch.arange(mask.size(-1), device=device)
133
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
134
+
135
+ mask = mask.to(dtype)
136
+
137
+ if past_key_values_length > 0:
138
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
139
+
140
+ # add lower triangular sliding window mask if necessary
141
+ if sliding_window is not None:
142
+ diagonal = past_key_values_length - sliding_window + 1
143
+
144
+ context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
145
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
146
+
147
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
148
+
149
+ @staticmethod
150
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
151
+ """
152
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
153
+ """
154
+ bsz, src_len = mask.size()
155
+ tgt_len = tgt_len if tgt_len is not None else src_len
156
+
157
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
158
+
159
+ inverted_mask = 1.0 - expanded_mask
160
+
161
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
162
+
163
+
164
+ def _prepare_4d_causal_attention_mask(
165
+ attention_mask: Optional[torch.Tensor],
166
+ input_shape: Union[torch.Size, Tuple, List],
167
+ inputs_embeds: torch.Tensor,
168
+ past_key_values_length: int,
169
+ sliding_window: Optional[int] = None,
170
+ ):
171
+ """
172
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
173
+ `(batch_size, key_value_length)`
174
+
175
+ Args:
176
+ attention_mask (`torch.Tensor` or `None`):
177
+ A 2D attention mask of shape `(batch_size, key_value_length)`
178
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
179
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
180
+ inputs_embeds (`torch.Tensor`):
181
+ The embedded inputs as a torch Tensor.
182
+ past_key_values_length (`int`):
183
+ The length of the key value cache.
184
+ sliding_window (`int`, *optional*):
185
+ If the model uses windowed attention, a sliding window should be passed.
186
+ """
187
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
188
+
189
+ key_value_length = input_shape[-1] + past_key_values_length
190
+
191
+ # 4d mask is passed through the layers
192
+ if attention_mask is not None:
193
+ attention_mask = attn_mask_converter.to_4d(
194
+ attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
195
+ )
196
+ else:
197
+ attention_mask = attn_mask_converter.to_causal_4d(
198
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
199
+ )
200
+
201
+ return attention_mask
202
+
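A minimal sketch of how `_prepare_4d_causal_attention_mask` is typically called (the shapes and values below are illustrative; the import path mirrors this file's location in the repo):

```python
import torch
from mplug_docowl.model.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Hypothetical inputs: batch of 2, query length 4, no cached key/values, one padded position.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]])      # (batch_size, key_value_length)
inputs_embeds = torch.zeros(2, 4, 8)               # (batch_size, query_length, hidden_size)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask, (2, 4), inputs_embeds, past_key_values_length=0
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 4]); masked positions carry large negative values
```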
203
+
204
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
205
+ """
206
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
207
+ `(batch_size, key_value_length)`
208
+
209
+ Args:
210
+ mask (`torch.Tensor` or `None`):
211
+ A 2D attention mask of shape `(batch_size, key_value_length)`
212
+ dtype (`torch.dtype`):
213
+ The torch dtype the created mask shall have.
214
+ tgt_len (`int`):
215
+ The target length or query length the created mask shall have.
216
+ """
217
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
218
+
219
+
220
+ def _create_4d_causal_attention_mask(
221
+ input_shape: Union[torch.Size, Tuple, List],
222
+ dtype: torch.dtype,
223
+ device: torch.device,
224
+ past_key_values_length: int = 0,
225
+ sliding_window: Optional[int] = None,
226
+ ):
227
+ """
228
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
229
+
230
+ Args:
231
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
232
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
233
+ dtype (`torch.dtype`):
234
+ The torch dtype the created mask shall have.
235
+ device (`int`):
236
+ The torch device the created mask shall have.
237
+ sliding_window (`int`, *optional*):
238
+ If the model uses windowed attention, a sliding window should be passed.
239
+ """
240
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
241
+
242
+ key_value_length = past_key_values_length + input_shape[-1]
243
+ attention_mask = attn_mask_converter.to_causal_4d(
244
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
245
+ )
246
+
247
+ return attention_mask
mplug_docowl/model/modeling_llama2.py ADDED
@@ -0,0 +1,486 @@
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+
11
+ import transformers
12
+ from transformers.models.llama.modeling_llama import *
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.utils import logging
15
+
16
+ from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
17
+ from .configuration_mplug_docowl import LlamaConfig
18
+
19
+ class MultiwayNetwork(nn.Module):
20
+
21
+ def __init__(self, module_provider, num_multiway=2):
22
+ super(MultiwayNetwork, self).__init__()
23
+
24
+ self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
25
+
26
+ def forward(self, hidden_states, multiway_indices):
27
+
28
+ if len(self.multiway) == 1:
29
+ return self.multiway[0](hidden_states)
30
+
31
+ output_hidden_states = torch.empty_like(hidden_states)
32
+
33
+ for idx, subway in enumerate(self.multiway):
34
+ local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
35
+ hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
36
+ if hidden.numel():
37
+ output = subway(hidden)
38
+ if isinstance(output, tuple):
39
+ output = output[0]
40
+ output = output.squeeze(1)
41
+ output_hidden_states[local_indices] = output
42
+
43
+ return output_hidden_states.contiguous()
44
+
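A small standalone sketch of how `MultiwayNetwork` routes tokens (the sizes and indicator values are illustrative; the import path mirrors this file's location in the repo): each position is dispatched to the sub-module selected by its indicator, and the per-way outputs are scattered back into a single tensor.

```python
import torch
from functools import partial
from torch import nn
from mplug_docowl.model.modeling_llama2 import MultiwayNetwork

# Two parallel Linear "ways" over a 6-token sequence (illustrative sizes).
net = MultiwayNetwork(module_provider=partial(nn.Linear, in_features=8, out_features=8))

hidden_states = torch.randn(1, 6, 8)                       # (batch, seq_len, hidden)
modality_indicators = torch.tensor([[0, 0, 0, 1, 1, 1]])   # per-token way index
output = net(hidden_states, modality_indicators)
print(output.shape)  # torch.Size([1, 6, 8])
```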
45
+
46
+ class LlamaAttention(nn.Module):
47
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
48
+
49
+ def __init__(self, config: LlamaConfig):
50
+ super().__init__()
51
+ self.config = config
52
+ self.hidden_size = config.hidden_size
53
+ self.num_heads = config.num_attention_heads
54
+ self.head_dim = self.hidden_size // self.num_heads
55
+ self.num_key_value_heads = config.num_key_value_heads
56
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
57
+ self.max_position_embeddings = config.max_position_embeddings
58
+ self.rope_theta = config.rope_theta
59
+
60
+ if (self.head_dim * self.num_heads) != self.hidden_size:
61
+ raise ValueError(
62
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
63
+ f" and `num_heads`: {self.num_heads})."
64
+ )
65
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
66
+ self.k_proj = MultiwayNetwork(module_provider=partial(
67
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
68
+ )
69
+ self.v_proj = MultiwayNetwork(module_provider=partial(
70
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
71
+ )
72
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
73
+ self._init_rope()
74
+
75
+ def _init_rope(self):
76
+ if self.config.rope_scaling is None:
77
+ self.rotary_emb = LlamaRotaryEmbedding(
78
+ self.head_dim,
79
+ max_position_embeddings=self.max_position_embeddings,
80
+ base=self.rope_theta,
81
+ )
82
+ else:
83
+ scaling_type = self.config.rope_scaling["type"]
84
+ scaling_factor = self.config.rope_scaling["factor"]
85
+ if scaling_type == "linear":
86
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
87
+ self.head_dim,
88
+ max_position_embeddings=self.max_position_embeddings,
89
+ scaling_factor=scaling_factor,
90
+ base=self.rope_theta,
91
+ )
92
+ elif scaling_type == "dynamic":
93
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
94
+ self.head_dim,
95
+ max_position_embeddings=self.max_position_embeddings,
96
+ scaling_factor=scaling_factor,
97
+ base=self.rope_theta,
98
+ )
99
+ else:
100
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
101
+
102
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
103
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
104
+
105
+ def forward(
106
+ self,
107
+ hidden_states: torch.Tensor,
108
+ modality_indicators: torch.Tensor,
109
+ attention_mask: Optional[torch.Tensor] = None,
110
+ position_ids: Optional[torch.LongTensor] = None,
111
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
112
+ output_attentions: bool = False,
113
+ use_cache: bool = False,
114
+ padding_mask: Optional[torch.LongTensor] = None,
115
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
116
+ bsz, q_len, _ = hidden_states.size()
117
+
118
+ query_states = self.q_proj(hidden_states)
119
+ key_states = self.k_proj(hidden_states, modality_indicators)
120
+ value_states = self.v_proj(hidden_states, modality_indicators)
121
+
122
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
123
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
124
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
125
+
126
+ kv_seq_len = key_states.shape[-2]
127
+ if past_key_value is not None:
128
+ kv_seq_len += past_key_value[0].shape[-2]
129
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
130
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
131
+
132
+ if past_key_value is not None:
133
+ # reuse k, v, self_attention
134
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
135
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
136
+
137
+ past_key_value = (key_states, value_states) if use_cache else None
138
+
139
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
140
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
141
+
142
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
143
+
144
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
145
+ raise ValueError(
146
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
147
+ f" {attn_weights.size()}"
148
+ )
149
+
150
+ if attention_mask is not None:
151
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
152
+ raise ValueError(
153
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
154
+ )
155
+ attn_weights = attn_weights + attention_mask
156
+
157
+ # upcast attention to fp32
158
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
159
+ attn_output = torch.matmul(attn_weights, value_states)
160
+
161
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
162
+ raise ValueError(
163
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
164
+ f" {attn_output.size()}"
165
+ )
166
+
167
+ attn_output = attn_output.transpose(1, 2).contiguous()
168
+
169
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
170
+
171
+ attn_output = self.o_proj(attn_output)
172
+
173
+ if not output_attentions:
174
+ attn_weights = None
175
+
176
+ return attn_output, attn_weights, past_key_value
177
+
178
+
179
+
180
+ class LlamaDecoderLayer(nn.Module):
181
+ def __init__(self, config: LlamaConfig):
182
+ super().__init__()
183
+ self.hidden_size = config.hidden_size
184
+ self.self_attn = LlamaAttention(config=config)
185
+ self.mlp = LlamaMLP(config)
186
+ self.input_layernorm = MultiwayNetwork(module_provider=partial(
187
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
188
+ ))
189
+ self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
190
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
191
+ ))
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ modality_indicators: torch.Tensor = None,
197
+ attention_mask: Optional[torch.Tensor] = None,
198
+ position_ids: Optional[torch.LongTensor] = None,
199
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
200
+ output_attentions: Optional[bool] = False,
201
+ use_cache: Optional[bool] = False,
202
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
203
+ """
204
+ Args:
205
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
206
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
207
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
208
+ output_attentions (`bool`, *optional*):
209
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
210
+ returned tensors for more detail.
211
+ use_cache (`bool`, *optional*):
212
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
213
+ (see `past_key_values`).
214
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
215
+ """
216
+
217
+ residual = hidden_states
218
+
219
+ hidden_states = self.input_layernorm(hidden_states, modality_indicators)
220
+
221
+ # Self Attention
222
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
223
+ hidden_states=hidden_states,
224
+ modality_indicators=modality_indicators,
225
+ attention_mask=attention_mask,
226
+ position_ids=position_ids,
227
+ past_key_value=past_key_value,
228
+ output_attentions=output_attentions,
229
+ use_cache=use_cache,
230
+ )
231
+ hidden_states = residual + hidden_states
232
+
233
+ # Fully Connected
234
+ residual = hidden_states
235
+ hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
236
+ hidden_states = self.mlp(hidden_states)
237
+ hidden_states = residual + hidden_states
238
+
239
+ outputs = (hidden_states,)
240
+
241
+ if output_attentions:
242
+ outputs += (self_attn_weights,)
243
+
244
+ if use_cache:
245
+ outputs += (present_key_value,)
246
+
247
+ return outputs
248
+
249
+
250
+ def model_forward(
251
+ self,
252
+ input_ids: torch.LongTensor = None,
253
+ modality_indicators: torch.Tensor = None,
254
+ attention_mask: Optional[torch.Tensor] = None,
255
+ position_ids: Optional[torch.LongTensor] = None,
256
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
257
+ inputs_embeds: Optional[torch.FloatTensor] = None,
258
+ use_cache: Optional[bool] = None,
259
+ output_attentions: Optional[bool] = None,
260
+ output_hidden_states: Optional[bool] = None,
261
+ return_dict: Optional[bool] = None,
262
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
263
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
264
+ output_hidden_states = (
265
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
266
+ )
267
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
268
+
269
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
270
+
271
+ # retrieve input_ids and inputs_embeds
272
+ if input_ids is not None and inputs_embeds is not None:
273
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
274
+ elif input_ids is not None:
275
+ batch_size, seq_length = input_ids.shape
276
+ elif inputs_embeds is not None:
277
+ batch_size, seq_length, _ = inputs_embeds.shape
278
+ else:
279
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
280
+
281
+ seq_length_with_past = seq_length
282
+ past_key_values_length = 0
283
+
284
+ if past_key_values is not None:
285
+ past_key_values_length = past_key_values[0][0].shape[2]
286
+ seq_length_with_past = seq_length_with_past + past_key_values_length
287
+
288
+ if position_ids is None:
289
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
290
+ position_ids = torch.arange(
291
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
292
+ )
293
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
294
+ else:
295
+ position_ids = position_ids.view(-1, seq_length).long()
296
+
297
+ if inputs_embeds is None:
298
+ inputs_embeds = self.embed_tokens(input_ids)
299
+ # embed positions
300
+ if attention_mask is None:
301
+ attention_mask = torch.ones(
302
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
303
+ )
304
+ attention_mask = self._prepare_decoder_attention_mask(
305
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
306
+ )
307
+
308
+ hidden_states = inputs_embeds
309
+
310
+ if self.gradient_checkpointing and self.training:
311
+ if use_cache:
312
+ logger.warning_once(
313
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
314
+ )
315
+ use_cache = False
316
+
317
+ # decoder layers
318
+ all_hidden_states = () if output_hidden_states else None
319
+ all_self_attns = () if output_attentions else None
320
+ next_decoder_cache = () if use_cache else None
321
+
322
+ for idx, decoder_layer in enumerate(self.layers):
323
+ if output_hidden_states:
324
+ all_hidden_states += (hidden_states,)
325
+
326
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
327
+
328
+ if self.gradient_checkpointing and self.training:
329
+
330
+ def create_custom_forward(module):
331
+ def custom_forward(*inputs):
332
+ # None for past_key_value
333
+ return module(*inputs, past_key_value, output_attentions)
334
+
335
+ return custom_forward
336
+
337
+ layer_outputs = torch.utils.checkpoint.checkpoint(
338
+ create_custom_forward(decoder_layer),
339
+ hidden_states,
340
+ modality_indicators,
341
+ attention_mask,
342
+ position_ids,
343
+ )
344
+ else:
345
+ layer_outputs = decoder_layer(
346
+ hidden_states,
347
+ modality_indicators=modality_indicators,
348
+ attention_mask=attention_mask,
349
+ position_ids=position_ids,
350
+ past_key_value=past_key_value,
351
+ output_attentions=output_attentions,
352
+ use_cache=use_cache,
353
+ )
354
+
355
+ hidden_states = layer_outputs[0]
356
+
357
+ if use_cache:
358
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
359
+
360
+ if output_attentions:
361
+ all_self_attns += (layer_outputs[1],)
362
+
363
+ hidden_states = self.norm(hidden_states)
364
+
365
+ # add hidden states from the last decoder layer
366
+ if output_hidden_states:
367
+ all_hidden_states += (hidden_states,)
368
+
369
+ next_cache = next_decoder_cache if use_cache else None
370
+ if not return_dict:
371
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
372
+ return BaseModelOutputWithPast(
373
+ last_hidden_state=hidden_states,
374
+ past_key_values=next_cache,
375
+ hidden_states=all_hidden_states,
376
+ attentions=all_self_attns,
377
+ )
378
+
379
+
380
+ def causal_model_forward(
381
+ self,
382
+ input_ids: torch.LongTensor = None,
383
+ modality_indicators: torch.Tensor = None,
384
+ attention_mask: Optional[torch.Tensor] = None,
385
+ position_ids: Optional[torch.LongTensor] = None,
386
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
387
+ inputs_embeds: Optional[torch.FloatTensor] = None,
388
+ labels: Optional[torch.LongTensor] = None,
389
+ use_cache: Optional[bool] = None,
390
+ output_attentions: Optional[bool] = None,
391
+ output_hidden_states: Optional[bool] = None,
392
+ return_dict: Optional[bool] = None,
393
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
394
+ r"""
395
+ Args:
396
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
397
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
398
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
399
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
400
+
401
+ Returns:
402
+
403
+ Example:
404
+
405
+ ```python
406
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
407
+
408
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
409
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
410
+
411
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
412
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
413
+
414
+ >>> # Generate
415
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
416
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
417
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
418
+ ```"""
419
+
420
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
421
+ output_hidden_states = (
422
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
423
+ )
424
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
425
+
426
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
427
+ outputs = self.model(
428
+ input_ids=input_ids,
429
+ modality_indicators=modality_indicators,
430
+ attention_mask=attention_mask,
431
+ position_ids=position_ids,
432
+ past_key_values=past_key_values,
433
+ inputs_embeds=inputs_embeds,
434
+ use_cache=use_cache,
435
+ output_attentions=output_attentions,
436
+ output_hidden_states=output_hidden_states,
437
+ return_dict=return_dict,
438
+ )
439
+
440
+ hidden_states = outputs[0]
441
+ if self.config.pretraining_tp > 1:
442
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
443
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
444
+ logits = torch.cat(logits, dim=-1)
445
+ else:
446
+ logits = self.lm_head(hidden_states)
447
+ logits = logits.float()
448
+
449
+ loss = None
450
+ if labels is not None:
451
+ # Shift so that tokens < n predict n
452
+ shift_logits = logits[..., :-1, :].contiguous()
453
+ shift_labels = labels[..., 1:].contiguous()
454
+ # Flatten the tokens
455
+ loss_fct = CrossEntropyLoss()
456
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
457
+ shift_labels = shift_labels.view(-1)
458
+ # Enable model parallelism
459
+ shift_labels = shift_labels.to(shift_logits.device)
460
+ loss = loss_fct(shift_logits, shift_labels)
461
+
462
+ if not return_dict:
463
+ output = (logits,) + outputs[1:]
464
+ return (loss,) + output if loss is not None else output
465
+
466
+ return CausalLMOutputWithPast(
467
+ loss=loss,
468
+ logits=logits,
469
+ past_key_values=outputs.past_key_values,
470
+ hidden_states=outputs.hidden_states,
471
+ attentions=outputs.attentions,
472
+ )
473
+
474
+ def replace_llama_modality_adaptive():
475
+ transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
476
+ transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
477
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
478
+ transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
479
+ transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
480
+
481
+
482
+ if __name__ == "__main__":
483
+ replace_llama_modality_adaptive()
484
+ config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
485
+ model = transformers.LlamaForCausalLM(config)
486
+ print(model)
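The patched attention above routes the key/value projections (and, in LlamaDecoderLayer, the two RMSNorms) through modality-specific branches selected per token by modality_indicators, while the query projection stays shared. As a rough illustration of that routing pattern only (a minimal sketch, not the repository's MultiwayNetwork implementation, which is defined earlier in this file):

import torch
import torch.nn as nn

class TinyMultiway(nn.Module):
    # One copy of the wrapped module per modality; each position is dispatched by its indicator.
    def __init__(self, module_provider, num_experts=2):
        super().__init__()
        self.experts = nn.ModuleList([module_provider() for _ in range(num_experts)])

    def forward(self, hidden_states, modality_indicators):
        output = torch.zeros_like(self.experts[0](hidden_states))
        for idx, expert in enumerate(self.experts):
            mask = modality_indicators == idx          # (batch, seq_len) boolean mask
            if mask.any():
                output[mask] = expert(hidden_states[mask])
        return output

proj = TinyMultiway(lambda: nn.Linear(8, 8))
x = torch.randn(2, 5, 8)
indicators = torch.tensor([[0, 0, 1, 1, 0],
                           [1, 1, 1, 0, 0]])           # 0 = text token, 1 = visual token
print(proj(x, indicators).shape)                       # torch.Size([2, 5, 8])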
mplug_docowl/model/modeling_mplug_docowl.py ADDED
@@ -0,0 +1,313 @@
1
+ # Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn import CrossEntropyLoss
21
+
22
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+
25
+ from .configuration_mplug_docowl import (MPLUGDocOwlConfig, MplugOwlVisionConfig, MplugDocOwlHReducerConfig)
26
+ from .visual_encoder import MplugOwlVisionModel, MplugDocOwlHReducerModel
27
+ from .modeling_llama2 import replace_llama_modality_adaptive
28
+ from mplug_docowl.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
29
+ from icecream import ic
30
+
31
+ class MPLUGDocOwlMetaModel:
32
+ def __init__(self, config):
33
+ super(MPLUGDocOwlMetaModel, self).__init__(config)
34
+ self.vision_model = MplugOwlVisionModel(
35
+ MplugOwlVisionConfig(**config.visual_config["visual_model"])
36
+ )
37
+
38
+ self.vision2text = MplugDocOwlHReducerModel(
39
+ MplugDocOwlHReducerConfig(**config.visual_config["visual_hreducer"]), config.hidden_size
40
+ )
41
+
42
+ def get_vision_tower(self):
43
+ vision_model = getattr(self, 'vision_model', None)
44
+ if type(vision_model) is list:
45
+ vision_model = vision_model[0]
46
+ return vision_model
47
+
48
+ def get_vision2text(self):
49
+ vision2text = getattr(self, 'vision2text', None)
50
+ if type(vision2text) is list:
51
+ vision2text = vision2text[0]
52
+ return vision2text
53
+
54
+ class MPLUGDocOwlMetaForCausalLM(ABC):
55
+ @abstractmethod
56
+ def get_model(self):
57
+ pass
58
+
59
+ def encode_images(self, images, patch_positions):
60
+ image_features = self.get_model().vision_model(images).last_hidden_state
61
+ image_features = self.get_model().vision2text(encoder_hidden_states=image_features)
62
+ return image_features
63
+
64
+ def prepare_inputs_labels_for_multimodal(
65
+ self, input_ids, attention_mask, past_key_values, labels, images, patch_positions
66
+ ):
67
+ if images is None or input_ids.shape[1] == 1:
68
+ if past_key_values is not None and images is not None and input_ids.shape[1] == 1:
69
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
70
+ multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
71
+ return input_ids, multiway_indices, attention_mask, past_key_values, None, labels
72
+
73
+ if type(images) is list or images.ndim == 5:
74
+ concat_images = torch.cat([image for image in images], dim=0)
75
+ image_features = self.encode_images(concat_images, patch_positions)
76
+ split_sizes = [image.shape[0] for image in images]
77
+ image_features = torch.split(image_features, split_sizes, dim=0)
78
+ image_features = [x.flatten(0, 1) for x in image_features]
79
+ else:
80
+ image_features = self.encode_images(images, patch_positions) # Sum(Crop Image Number) x L x d
81
+
82
+ new_input_embeds = []
83
+ new_modality_indicators = []
84
+ new_labels = [] if labels is not None else None
85
+ cur_image_idx = 0
86
+ for batch_idx, cur_input_ids in enumerate(input_ids):
87
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
88
+ # multimodal LLM, but the current sample is not multimodal
89
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
90
+ half_len = cur_input_ids.shape[0] // 2
91
+ cur_image_features = image_features[cur_image_idx]
92
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
93
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
94
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
95
+ new_input_embeds.append(cur_input_embeds)
96
+
97
+ cur_modality_indicators = torch.zeros(len(cur_input_embeds)).long().to(self.device)
98
+ new_modality_indicators.append(cur_modality_indicators)
99
+ if labels is not None:
100
+ new_labels.append(labels[batch_idx])
101
+ cur_image_idx += 1
102
+ continue
103
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
104
+ cur_new_input_embeds = []
105
+ cur_modality_indicators = []
106
+ if labels is not None:
107
+ cur_labels = labels[batch_idx]
108
+ cur_new_labels = []
109
+ assert cur_labels.shape == cur_input_ids.shape
110
+ while image_token_indices.numel() > 0:
111
+ cur_image_features = image_features[cur_image_idx]
112
+ image_token_start = image_token_indices[0]
113
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
114
+ cur_new_input_embeds.append(cur_image_features)
115
+
116
+ # Add modality indicator
117
+ assert image_token_start == len(cur_input_ids[:image_token_start])
118
+ cur_modality_indicators.append(torch.zeros(len(cur_input_ids[:image_token_start])).long())
119
+ cur_modality_indicators.append(torch.ones(len(cur_image_features)).long())
120
+
121
+ if labels is not None:
122
+ cur_new_labels.append(cur_labels[:image_token_start])
123
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
124
+ cur_labels = cur_labels[image_token_start+1:]
125
+ cur_image_idx += 1
126
+ cur_input_ids = cur_input_ids[image_token_start+1:]
127
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
128
+ if cur_input_ids.numel() > 0:
129
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
130
+ cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long())
131
+ if labels is not None:
132
+ cur_new_labels.append(cur_labels)
133
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
134
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
135
+ new_input_embeds.append(cur_new_input_embeds)
136
+
137
+ # Modality
138
+ cur_modality_indicators = [x.to(device=self.device) for x in cur_modality_indicators]
139
+ cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0)
140
+ new_modality_indicators.append(cur_modality_indicators)
141
+
142
+
143
+ if labels is not None:
144
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
145
+ new_labels.append(cur_new_labels)
146
+
147
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
148
+ max_len = max(x.shape[0] for x in new_input_embeds)
149
+
150
+ # Embedding
151
+ new_input_embeds_align = []
152
+ for cur_new_embed in new_input_embeds:
153
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
154
+ new_input_embeds_align.append(cur_new_embed)
155
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
156
+
157
+ # Modality
158
+ new_modality_indicators_align = []
159
+ for cur_modality_indicator in new_modality_indicators:
160
+ cur_new_embed = torch.cat((cur_modality_indicator, torch.zeros(max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device)), dim=0)
161
+ new_modality_indicators_align.append(cur_new_embed)
162
+ new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0)
163
+
164
+ # Label
165
+ if labels is not None:
166
+ new_labels_align = []
167
+ _new_labels = new_labels
168
+ for cur_new_label in new_labels:
169
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
170
+ new_labels_align.append(cur_new_label)
171
+ new_labels = torch.stack(new_labels_align, dim=0)
172
+
173
+ # Attention Mask
174
+ if attention_mask is not None:
175
+ new_attention_mask = []
176
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
177
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
178
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
179
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
180
+ new_attention_mask.append(cur_new_attention_mask)
181
+ attention_mask = torch.stack(new_attention_mask, dim=0)
182
+ assert attention_mask.shape == new_labels.shape
183
+ else:
184
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
185
+ new_modality_indicators = torch.stack(new_modality_indicators, dim=0)
186
+ if labels is not None:
187
+ new_labels = torch.stack(new_labels, dim=0)
188
+
189
+ if attention_mask is not None:
190
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
191
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
192
+ assert attention_mask.shape == new_input_embeds.shape[:2]
193
+ return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels
194
+
195
+
196
+
197
+ class MPLUGDocOwlLlamaModel(MPLUGDocOwlMetaModel, LlamaModel):
198
+ config_class = MPLUGDocOwlConfig
199
+
200
+ def __init__(self, config: MPLUGDocOwlConfig):
201
+ super(MPLUGDocOwlLlamaModel, self).__init__(config)
202
+
203
+
204
+ class MPLUGDocOwlLlamaForCausalLM(LlamaForCausalLM, MPLUGDocOwlMetaForCausalLM):
205
+ config_class = MPLUGDocOwlConfig
206
+
207
+ def __init__(self, config):
208
+ super(LlamaForCausalLM, self).__init__(config)
209
+ self.model = MPLUGDocOwlLlamaModel(config)
210
+
211
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
212
+
213
+ # Initialize weights and apply final processing
214
+ self.post_init()
215
+
216
+ def get_model(self):
217
+ return self.model
218
+
219
+ def forward(
220
+ self,
221
+ input_ids: torch.LongTensor = None,
222
+ # modality_indicators: torch.LongTensor = None,
223
+ attention_mask: Optional[torch.Tensor] = None,
224
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
225
+ inputs_embeds: Optional[torch.FloatTensor] = None,
226
+ labels: Optional[torch.LongTensor] = None,
227
+ use_cache: Optional[bool] = None,
228
+ output_attentions: Optional[bool] = None,
229
+ output_hidden_states: Optional[bool] = None,
230
+ images: Optional[torch.FloatTensor] = None,
231
+ patch_positions: Optional[torch.LongTensor] = None,
232
+ return_dict: Optional[bool] = None,
233
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
234
+
235
+ # print('modeling_mplug_docowl.py patch_positions:', patch_positions)
236
+
237
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
238
+ output_hidden_states = (
239
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
240
+ )
241
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
242
+ input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \
243
+ self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, patch_positions)
244
+
245
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
246
+ outputs = self.model(
247
+ input_ids=input_ids,
248
+ modality_indicators=modality_indicators,
249
+ attention_mask=attention_mask,
250
+ past_key_values=past_key_values,
251
+ inputs_embeds=inputs_embeds,
252
+ use_cache=use_cache,
253
+ output_attentions=output_attentions,
254
+ output_hidden_states=output_hidden_states,
255
+ return_dict=return_dict
256
+ )
257
+
258
+ hidden_states = outputs[0]
259
+ logits = self.lm_head(hidden_states)
260
+
261
+ loss = None
262
+ if labels is not None:
263
+ # Shift so that tokens < n predict n
264
+ shift_logits = logits[..., :-1, :].contiguous()
265
+ shift_labels = labels[..., 1:].contiguous()
266
+ # Flatten the tokens
267
+ loss_fct = CrossEntropyLoss()
268
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
269
+ shift_labels = shift_labels.view(-1)
270
+ # Enable model/pipeline parallelism
271
+ shift_labels = shift_labels.to(shift_logits.device)
272
+ loss = loss_fct(shift_logits, shift_labels)
273
+
274
+ if not return_dict:
275
+ output = (logits,) + outputs[1:]
276
+ return (loss,) + output if loss is not None else output
277
+
278
+ return CausalLMOutputWithPast(
279
+ loss=loss,
280
+ logits=logits,
281
+ past_key_values=outputs.past_key_values,
282
+ hidden_states=outputs.hidden_states,
283
+ attentions=outputs.attentions,
284
+ )
285
+
286
+ def prepare_inputs_for_generation(
287
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
288
+ ):
289
+ if past_key_values:
290
+ input_ids = input_ids[:, -1:]
291
+
292
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
293
+ if inputs_embeds is not None and past_key_values is None:
294
+ model_inputs = {"inputs_embeds": inputs_embeds}
295
+ else:
296
+ model_inputs = {"input_ids": input_ids}
297
+
298
+ model_inputs.update(
299
+ {
300
+ "past_key_values": past_key_values,
301
+ "use_cache": kwargs.get("use_cache"),
302
+ "attention_mask": attention_mask,
303
+ "images": kwargs.get("images", None),
304
+ "patch_positions": kwargs.get("patch_positions", None),
305
+ }
306
+ )
307
+ return model_inputs
308
+
309
+ AutoConfig.register("mplug_docowl", MPLUGDocOwlConfig)
310
+ AutoModelForCausalLM.register(MPLUGDocOwlConfig, MPLUGDocOwlLlamaForCausalLM)
311
+
312
+ replace_llama_modality_adaptive()
313
+
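For intuition, prepare_inputs_labels_for_multimodal above splices the visual features into the embedded token sequence at every image placeholder and emits modality_indicators that are 0 for text positions and 1 for the spliced visual positions. A toy layout check with made-up values (the -200 placeholder id and the 257-token visual length are assumptions for this illustration; the real constants come from mplug_docowl/constants.py and the H-Reducer output):

import torch

IMAGE_TOKEN_INDEX = -200          # assumed placeholder id for this sketch
visual_len = 257                  # assumed number of visual tokens per image/crop

input_ids = torch.tensor([1, 319, IMAGE_TOKEN_INDEX, 825, 29973])    # "<s> A <image> what ?"
before = torch.where(input_ids == IMAGE_TOKEN_INDEX)[0].item()       # text tokens before the placeholder

modality_indicators = torch.cat([
    torch.zeros(before, dtype=torch.long),                           # leading text -> 0
    torch.ones(visual_len, dtype=torch.long),                        # spliced visual features -> 1
    torch.zeros(len(input_ids) - before - 1, dtype=torch.long),      # trailing text -> 0
])
print(modality_indicators.shape)                                     # torch.Size([261]) = 4 text + 257 visual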
mplug_docowl/model/utils.py ADDED
@@ -0,0 +1,20 @@
1
+ from transformers import AutoConfig
2
+
3
+
4
+ def auto_upgrade(config):
5
+ cfg = AutoConfig.from_pretrained(config)
6
+ if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type:
7
+ assert cfg.model_type == 'llama' # the guard above already excludes 'mplug_owl2'; old checkpoints still report the base llama model_type
8
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
+ if confirm.lower() in ["y", "yes"]:
12
+ print("Upgrading checkpoint...")
13
+ assert len(cfg.architectures) == 1
14
+ setattr(cfg.__class__, "model_type", "mplug_owl2")
15
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
+ cfg.save_pretrained(config)
17
+ print("Checkpoint upgraded.")
18
+ else:
19
+ print("Checkpoint upgrade aborted.")
20
+ exit(1)
mplug_docowl/model/visual_encoder.py ADDED
@@ -0,0 +1,499 @@
1
+ import math
2
+ from typing import Any, Optional, Tuple, Union
3
+
4
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions
5
+ from transformers.modeling_utils import PreTrainedModel
6
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
+ import torch.nn.functional as F # needed for F.interpolate in get_abs_pos below
12
+ from icecream import ic
13
+ import einops
14
+ from einops import rearrange
15
+
16
+ def get_abs_pos(abs_pos, tgt_size):
17
+ # abs_pos: L, C
18
+ # tgt_size: M
19
+ # return: M, C
20
+ src_size = int(math.sqrt(abs_pos.size(0)))
21
+ tgt_size = int(math.sqrt(tgt_size))
22
+ dtype = abs_pos.dtype
23
+
24
+ if src_size != tgt_size:
25
+ return F.interpolate(
26
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
27
+ size=(tgt_size, tgt_size),
28
+ mode="bicubic",
29
+ align_corners=False,
30
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
31
+ else:
32
+ return abs_pos
33
+
34
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
35
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
36
+ """
37
+ grid_size: int of the grid height and width
38
+ return:
39
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
40
+ """
41
+ grid_h = np.arange(grid_size, dtype=np.float32)
42
+ grid_w = np.arange(grid_size, dtype=np.float32)
43
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
44
+ grid = np.stack(grid, axis=0)
45
+
46
+ grid = grid.reshape([2, 1, grid_size, grid_size])
47
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
48
+ if cls_token:
49
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
50
+ return pos_embed
51
+
52
+
53
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
54
+ assert embed_dim % 2 == 0
55
+
56
+ # use half of dimensions to encode grid_h
57
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
58
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
59
+
60
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
61
+ return emb
62
+
63
+
64
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
65
+ """
66
+ embed_dim: output dimension for each position
67
+ pos: a list of positions to be encoded: size (M,)
68
+ out: (M, D)
69
+ """
70
+ assert embed_dim % 2 == 0
71
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
72
+ omega /= embed_dim / 2.
73
+ omega = 1. / 10000**omega # (D/2,)
74
+
75
+ pos = pos.reshape(-1) # (M,)
76
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
77
+
78
+ emb_sin = np.sin(out) # (M, D/2)
79
+ emb_cos = np.cos(out) # (M, D/2)
80
+
81
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
82
+ return emb
83
+
84
+
85
+
86
+ class MplugOwlVisionEmbeddings(nn.Module):
87
+ def __init__(self, config):
88
+ super().__init__()
89
+ self.config = config
90
+ self.hidden_size = config.hidden_size
91
+ self.image_size = config.image_size
92
+ self.patch_size = config.patch_size
93
+
94
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
95
+
96
+ self.patch_embed = nn.Conv2d(
97
+ in_channels=3,
98
+ out_channels=self.hidden_size,
99
+ kernel_size=self.patch_size,
100
+ stride=self.patch_size,
101
+ bias=False,
102
+ )
103
+
104
+ self.num_patches = (self.image_size // self.patch_size) ** 2
105
+
106
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
107
+
108
+ self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
109
+
110
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
111
+ batch_size = pixel_values.size(0)
112
+ image_embeds = self.patch_embed(pixel_values)
113
+ image_embeds = image_embeds.flatten(2).transpose(1, 2)
114
+
115
+ class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
116
+ embeddings = torch.cat([class_embeds, image_embeds], dim=1)
117
+ embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
118
+ embeddings = self.pre_layernorm(embeddings)
119
+ return embeddings
120
+
121
+
122
+
123
+ class MplugOwlVisionAttention(nn.Module):
124
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
125
+
126
+ def __init__(self, config):
127
+ super().__init__()
128
+ self.config = config
129
+ self.hidden_size = config.hidden_size
130
+ self.num_heads = config.num_attention_heads
131
+ self.head_dim = self.hidden_size // self.num_heads
132
+ if self.head_dim * self.num_heads != self.hidden_size:
133
+ raise ValueError(
134
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
135
+ f" {self.num_heads})."
136
+ )
137
+ self.scale = self.head_dim**-0.5
138
+ self.dropout = nn.Dropout(config.attention_dropout)
139
+
140
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size)
141
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
142
+
143
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
144
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
145
+
146
+ def forward(
147
+ self,
148
+ hidden_states: torch.Tensor,
149
+ head_mask: Optional[torch.Tensor] = None,
150
+ output_attentions: Optional[bool] = False,
151
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
152
+ """Input shape: Batch x Time x Channel"""
153
+
154
+ bsz, seq_len, embed_dim = hidden_states.size()
155
+
156
+ mixed_qkv = self.query_key_value(hidden_states)
157
+
158
+ mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute(
159
+ 3, 0, 2, 1, 4
160
+ ) # [3, b, np, sq, hn]
161
+ query_states, key_states, value_states = (
162
+ mixed_qkv[0],
163
+ mixed_qkv[1],
164
+ mixed_qkv[2],
165
+ )
166
+ # if self.config.use_flash_attn and flash_attn_func is not None:
167
+ if False:
168
+ # [b*sq, np, hn]
169
+ query_states = query_states.permute(0, 2, 1, 3).contiguous()
170
+ query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1)
171
+
172
+ key_states = key_states.permute(0, 2, 1, 3).contiguous()
173
+ key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1)
174
+
175
+ value_states = value_states.permute(0, 2, 1, 3).contiguous()
176
+ value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1)
177
+
178
+ cu_seqlens = torch.arange(
179
+ 0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device
180
+ )
181
+
182
+ context_layer = flash_attn_func(
183
+ query_states,
184
+ key_states,
185
+ value_states,
186
+ cu_seqlens,
187
+ cu_seqlens,
188
+ seq_len,
189
+ seq_len,
190
+ self.dropout if self.training else 0.0,
191
+ softmax_scale=self.scale,
192
+ causal=False,
193
+ return_attn_probs=False,
194
+ )
195
+ # [b*sq, np, hn] => [b, sq, np, hn]
196
+ context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2))
197
+ else:
198
+ # Take the dot product between "query" and "key" to get the raw attention scores.
199
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
200
+
201
+ attention_scores = attention_scores * self.scale
202
+
203
+ # Normalize the attention scores to probabilities.
204
+ attention_probs = torch.softmax(attention_scores, dim=-1)
205
+
206
+ # This is actually dropping out entire tokens to attend to, which might
207
+ # seem a bit unusual, but is taken from the original Transformer paper.
208
+ attention_probs = self.dropout(attention_probs)
209
+
210
+ # Mask heads if we want to
211
+ if head_mask is not None:
212
+ attention_probs = attention_probs * head_mask
213
+
214
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
215
+
216
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
217
+ context_layer = context_layer.reshape(new_context_layer_shape)
218
+
219
+ output = self.dense(context_layer)
220
+
221
+ outputs = (output, attention_probs) if output_attentions else (output, None)
222
+
223
+ return outputs
224
+
225
+
226
+ class QuickGELU(nn.Module):
227
+ def forward(self, x: torch.Tensor):
228
+ return x * torch.sigmoid(1.702 * x)
229
+
230
+
231
+ class MplugOwlMLP(nn.Module):
232
+ def __init__(self, config):
233
+ super().__init__()
234
+ self.config = config
235
+ self.activation_fn = QuickGELU()
236
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
237
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
238
+
239
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
240
+ hidden_states = self.fc1(hidden_states)
241
+ hidden_states = self.activation_fn(hidden_states)
242
+ hidden_states = self.fc2(hidden_states)
243
+ return hidden_states
244
+
245
+
246
+ class MplugOwlVisionEncoderLayer(nn.Module):
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.hidden_size = config.hidden_size
250
+ self.self_attn = MplugOwlVisionAttention(config)
251
+ self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
252
+ self.mlp = MplugOwlMLP(config)
253
+ self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
254
+
255
+ def forward(
256
+ self,
257
+ hidden_states: torch.Tensor,
258
+ attention_mask: torch.Tensor,
259
+ output_attentions: Optional[bool] = False,
260
+ ) -> Tuple[torch.FloatTensor]:
261
+ """
262
+ Args:
263
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
264
+ attention_mask (`torch.FloatTensor`): attention mask of size
265
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
266
+ `(config.encoder_attention_heads,)`.
267
+ output_attentions (`bool`, *optional*):
268
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
269
+ returned tensors for more detail.
270
+ """
271
+ residual = hidden_states
272
+
273
+ hidden_states = self.input_layernorm(hidden_states)
274
+ hidden_states, attn_weights = self.self_attn(
275
+ hidden_states=hidden_states,
276
+ head_mask=attention_mask,
277
+ output_attentions=output_attentions,
278
+ )
279
+ hidden_states = hidden_states + residual
280
+ residual = hidden_states
281
+ hidden_states = self.post_attention_layernorm(hidden_states)
282
+ hidden_states = self.mlp(hidden_states)
283
+
284
+ hidden_states = hidden_states + residual
285
+
286
+ outputs = (hidden_states,)
287
+
288
+ if output_attentions:
289
+ outputs += (attn_weights,)
290
+
291
+ return outputs
292
+
293
+
294
+ class MplugOwlVisionEncoder(nn.Module):
295
+ """
296
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
297
+ [`MplugOwlVisionEncoderLayer`].
298
+
299
+ Args:
300
+ config (`MplugOwlVisionConfig`):
301
+ The corresponding vision configuration for the `MplugOwlEncoder`.
302
+ """
303
+
304
+ def __init__(self, config):
305
+ super().__init__()
306
+ self.config = config
307
+ self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
308
+ self.gradient_checkpointing = True
309
+
310
+ def forward(
311
+ self,
312
+ inputs_embeds,
313
+ attention_mask: Optional[torch.Tensor] = None,
314
+ output_attentions: Optional[bool] = None,
315
+ output_hidden_states: Optional[bool] = None,
316
+ return_dict: Optional[bool] = None,
317
+ ) -> Union[Tuple, BaseModelOutput]:
318
+ r"""
319
+ Args:
320
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
321
+ Embedded representation of the inputs. Should be float, not int tokens.
322
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
323
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
324
+
325
+ - 1 for tokens that are **not masked**,
326
+ - 0 for tokens that are **masked**.
327
+
328
+ [What are attention masks?](../glossary#attention-mask)
329
+ output_attentions (`bool`, *optional*):
330
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
331
+ returned tensors for more detail.
332
+ output_hidden_states (`bool`, *optional*):
333
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
334
+ for more detail.
335
+ return_dict (`bool`, *optional*):
336
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
337
+ """
338
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
339
+ output_hidden_states = (
340
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
341
+ )
342
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
343
+
344
+ encoder_states = () if output_hidden_states else None
345
+ all_attentions = () if output_attentions else None
346
+
347
+ hidden_states = inputs_embeds
348
+ for idx, encoder_layer in enumerate(self.layers):
349
+ if output_hidden_states:
350
+ encoder_states = encoder_states + (hidden_states,)
351
+ if self.gradient_checkpointing and self.training:
352
+
353
+ def create_custom_forward(module):
354
+ def custom_forward(*inputs):
355
+ return module(*inputs, output_attentions)
356
+
357
+ return custom_forward
358
+
359
+ layer_outputs = torch.utils.checkpoint.checkpoint(
360
+ create_custom_forward(encoder_layer),
361
+ hidden_states,
362
+ attention_mask,
363
+ )
364
+ else:
365
+ layer_outputs = encoder_layer(
366
+ hidden_states,
367
+ attention_mask,
368
+ output_attentions=output_attentions,
369
+ )
370
+
371
+ hidden_states = layer_outputs[0]
372
+
373
+ if output_attentions:
374
+ all_attentions = all_attentions + (layer_outputs[1],)
375
+
376
+ if output_hidden_states:
377
+ encoder_states = encoder_states + (hidden_states,)
378
+
379
+ if not return_dict:
380
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
381
+ return BaseModelOutput(
382
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
383
+ )
384
+
385
+
386
+ class MplugOwlVisionModel(PreTrainedModel):
387
+ main_input_name = "pixel_values"
388
+
389
+ def __init__(self, config):
390
+ super().__init__(config)
391
+ self.config = config
392
+ self.hidden_size = config.hidden_size
393
+
394
+ self.embeddings = MplugOwlVisionEmbeddings(config)
395
+ self.encoder = MplugOwlVisionEncoder(config)
396
+ self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
397
+
398
+ self.post_init()
399
+
400
+
401
+ def forward(
402
+ self,
403
+ pixel_values: Optional[torch.FloatTensor] = None,
404
+ output_attentions: Optional[bool] = None,
405
+ output_hidden_states: Optional[bool] = None,
406
+ return_dict: Optional[bool] = None,
407
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
408
+ r"""
409
+ Returns:
410
+
411
+ """
412
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
413
+ output_hidden_states = (
414
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
415
+ )
416
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
417
+
418
+ if pixel_values is None:
419
+ raise ValueError("You have to specify pixel_values")
420
+
421
+ hidden_states = self.embeddings(pixel_values)
422
+
423
+ encoder_outputs = self.encoder(
424
+ inputs_embeds=hidden_states,
425
+ output_attentions=output_attentions,
426
+ output_hidden_states=output_hidden_states,
427
+ return_dict=return_dict,
428
+ )
429
+
430
+ last_hidden_state = encoder_outputs[0]
431
+ last_hidden_state = self.post_layernorm(last_hidden_state)
432
+
433
+ pooled_output = last_hidden_state[:, 0, :]
434
+ pooled_output = self.post_layernorm(pooled_output)
435
+
436
+ if not return_dict:
437
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
438
+
439
+ return BaseModelOutputWithPooling(
440
+ last_hidden_state=last_hidden_state,
441
+ pooler_output=pooled_output,
442
+ hidden_states=encoder_outputs.hidden_states,
443
+ attentions=encoder_outputs.attentions,
444
+ )
445
+
446
+ def get_input_embeddings(self):
447
+ return self.embeddings
448
+
449
+
450
+ class MplugDocOwlHReducerModel(PreTrainedModel):
451
+ def __init__(self, config, language_hidden_size):
452
+ super().__init__(config)
453
+ self.config = config
454
+ self.ln_q = torch.nn.LayerNorm(self.config.hidden_size, eps=1e-6)
455
+ self.conv_shape = (int(self.config.conv_shape.split('x')[0]), int(self.config.conv_shape.split('x')[1])) # parse the 'axb' conv_shape string into a (kernel_h, kernel_w) tuple
456
+ self.conv_patch=self.conv_shape[0]*self.conv_shape[1]
457
+ ## feature interaction with a conv layer
458
+ self.reducer_before = torch.nn.Sequential(
459
+ nn.Conv2d(self.config.hidden_size, self.conv_patch*self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True),
460
+ nn.GELU()
461
+ )
462
+ ## reduce visual feature length with a conv layer
463
+ self.reducer = nn.Conv2d(self.config.hidden_size, self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True)
464
+ ## align visual features with language embedding with fc
465
+ self.visual_fc = torch.nn.Linear(self.config.hidden_size, language_hidden_size)
466
+ self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
467
+
468
+ self.post_init()
469
+
470
+ def forward(
471
+ self,
472
+ encoder_hidden_states=None
473
+ ):
474
+ r"""
475
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
476
+ here batch_size is the total number of images (global image + crops) in the batch.
477
+ Sequence of hidden-states at the output of the last layer of the encoder.
478
+ """
479
+ encoder_hidden_states = encoder_hidden_states[:,1:,:] # remove the first cls token
480
+ B, L, C = encoder_hidden_states.shape # B, 1024=(448/14)^2, 1024
481
+
482
+ ## feature interaction with a conv layer
483
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L)))
484
+ hidden_states = self.reducer_before(encoder_hidden_states) # B 4D H W/4
485
+ ## reduce seq length with a conv layer
486
+ """hidden_states = hidden_states.flatten(2).transpose(1, 2) # B 4D H W/4 -> B 4D H*W/4 -> B H*W/4 4D
487
+ hidden_states = rearrange(hidden_states, 'B L (X D) -> B (L X) D', X=self.conv_patch) # B (H W) D
488
+ hidden_states = rearrange(hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L))) # B D H W """
489
+ hidden_states = rearrange(hidden_states, 'B (X D) H W -> B D H (W X)', X=self.conv_patch) # B 4D H W/4 -> B D H W
490
+ sequence_output = self.reducer(hidden_states) # B,C,H,W -> B,C,H/conv_shape[1],W/(conv_shape[1])
491
+ sequence_output = sequence_output.flatten(2).transpose(1, 2) # B,C,H/conv_shape[1],W/(conv_shape[1]) -> B,C,L/conv_patch -> B,L/conv_patch,C
492
+ sequence_output = sequence_output.transpose(0, 1).contiguous() # L/conv_patch, B, C
493
+ ## align visual features with language embedding with fc
494
+ sequence_output = self.visual_fc(sequence_output) # L/conv_patch, B, h
495
+ sequence_output = sequence_output.transpose(0, 1).contiguous() # B, s/4, h
496
+ sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(B, 1, 1)], dim=1)
497
+
498
+ return sequence_output
499
+
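To make the shape comments in MplugDocOwlHReducerModel.forward concrete, here is a standalone walk-through under assumed example values (a 32x32 patch grid of 1024-dim ViT features, a '1x4' conv_shape and a 4096-dim language model; these are illustrative assumptions, not values read from this commit's configs):

import math
import torch
import torch.nn as nn
from einops import rearrange

B, L, C, h_lang = 2, 1024, 1024, 4096
conv_shape = (1, 4)                          # horizontal merging only
conv_patch = conv_shape[0] * conv_shape[1]

feats = torch.randn(B, L, C)                                          # ViT tokens without the CLS token
x = rearrange(feats, 'B (H W) D -> B D H W', H=int(math.sqrt(L)))     # (B, C, 32, 32)
x = nn.Conv2d(C, conv_patch * C, conv_shape, conv_shape)(x)           # reducer_before conv: (B, 4C, 32, 8)
x = rearrange(x, 'B (X D) H W -> B D H (W X)', X=conv_patch)          # regroup channels: (B, C, 32, 32)
x = nn.Conv2d(C, C, conv_shape, conv_shape)(x)                        # reducer: (B, C, 32, 8)
x = x.flatten(2).transpose(1, 2)                                      # (B, 256, C)
x = nn.Linear(C, h_lang)(x)                                           # visual_fc: (B, 256, h_lang)
print(x.shape)                                                        # torch.Size([2, 256, 4096]); the real module appends one vit_eos token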
mplug_docowl/processor.py ADDED
@@ -0,0 +1,219 @@
1
+ from einops import rearrange, repeat
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image, ImageFile
5
+ import random
6
+ from torchvision.ops.boxes import box_area
7
+
8
+ from torchvision.transforms.transforms import InterpolationMode
9
+ from torchvision.transforms import functional as F
10
+ import numpy as np
11
+ from icecream import ic
12
+
13
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
14
+ ImageFile.MAX_IMAGE_PIXELS = None
15
+ Image.MAX_IMAGE_PIXELS = None
16
+
17
+ def box_iou(boxes1, area1, boxes2, eps=1e-5):
18
+ area2 = box_area(boxes2)
19
+
20
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
21
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
22
+
23
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
24
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
25
+
26
+ union = area1[:, None] + area2 - inter
27
+
28
+ iou = inter / (union+eps)
29
+ return iou, union
30
+
31
+ def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
32
+ # anchors x1 y1 x2 y2
33
+
34
+ # image_size: (h, w)
35
+ # xyxy
36
+ input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
37
+
38
+ boxes1 = anchors
39
+ boxes2 = input_image_bbox
40
+ boxes3 = anchors.clone()
41
+ # y2
42
+ boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # used to compute a resolution-independent IoU
43
+
44
+ area1 = anchors_areas
45
+
46
+ iou, _ = box_iou(boxes1, area1, boxes2)
47
+ iou = iou.squeeze(1)
48
+ shape_iou, _ = box_iou(boxes1, area1, boxes3)
49
+ shape_iou = shape_iou.diag()
50
+ # prefer the anchor with the closest aspect ratio, then the closest resolution
51
+ index = torch.argmax(shape_iou*100+iou,dim=0)
52
+ return index
53
+
54
+ class AnchorResize(torch.nn.Module):
55
+
56
+ def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None):
57
+ super().__init__()
+ self.image_size = image_size # stored so that __repr__ below can reference it
58
+ # xyxy
59
+ self.anchors = torch.tensor(
60
+ [[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
61
+ for _ in anchors], requires_grad=False
62
+ )
63
+
64
+ self.anchor_areas = box_area(self.anchors)
65
+
66
+ self.interpolation = interpolation
67
+ self.antialias = antialias
68
+
69
+ def forward(self, img, skip_resize=False):
70
+ """
71
+ Args:
72
+ img (PIL Image or Tensor): Image to be scaled.
73
+
74
+ Returns:
75
+ PIL Image or Tensor: Rescaled image.
76
+ """
77
+ selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
78
+ target_size = self.anchors[selected_anchor][2:].tolist() # w,h
79
+ if skip_resize:
80
+ # for debug
81
+ return selected_anchor
82
+ return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
83
+
84
+ def __repr__(self) -> str:
85
+ detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
86
+ return f"{self.__class__.__name__}{detail}"
87
+
88
+ grid_dict = {
89
+ 'grid_1':[
90
+ (1,1)],
91
+ 'grid_4':[
92
+ (1,1),
93
+ (1,2),(2,1),
94
+ (1,3),(3,1),
95
+ (2,2),(1,4),(4,1)],
96
+ 'grid_9':[
97
+ (1,1),
98
+ (1,2),(2,1),
99
+ (1,3),(3,1),
100
+ (2,2),(1,4),(4,1),
101
+ (1,5),(5,1),
102
+ (1,6),(6,1),(2,3),(3,2),
103
+ (1,7),(7,1),
104
+ (4,2),(2,4),(1,8),(8,1),
105
+ (3,3),(1,9),(9,1)],
106
+ 'grid_3x3':[
107
+ (3,3)],
108
+ 'grid_20':[
109
+ (1, 1),
110
+ (1, 2), (2, 1),
111
+ (1, 3), (3, 1), (1, 4), (2, 2), (4, 1),
112
+ (1, 5), (5, 1),
113
+ (1, 6), (2, 3), (3, 2), (6, 1),
114
+ (1, 7), (7, 1),
115
+ (1, 8), (2, 4), (4, 2), (8, 1),
116
+ (1, 9), (3, 3), (9, 1),
117
+ (1, 10), (2, 5), (5, 2), (10, 1),
118
+ (1, 11), (11, 1),
119
+ (2, 6), (3, 4), (4, 3), (6, 2),
120
+ (2, 7), (7, 2),
121
+ (3, 5), (5, 3),
122
+ (2, 8), (4, 4), (8, 2),
123
+ (2, 9), (3, 6), (6, 3), (9, 2),
124
+ (2, 10), (4, 5), (5, 4), (10, 2)]
125
+ }
126
+
127
+ class DocProcessor():
128
+ def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False):
129
+ self.add_global_img = add_global_img
130
+ self.add_textual_crop_indicator = add_textual_crop_indicator
131
+ self.media_token= "<|image|>"
132
+ # h,w
133
+ if isinstance(image_size, int):
134
+ image_size = (image_size, image_size)
135
+ self.image_size = image_size
136
+ # h,w
137
+ anchors = grid_dict[anchors]
138
+ self.anchors = [tuple(_) for _ in anchors]
139
+ self.anchor_max = max([max(_) for _ in self.anchors])
140
+ # xywh -> xyxy
141
+ self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC)
142
+ self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
143
+ self.image_transform = transforms.Compose([
144
+ transforms.ToTensor(),
145
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
146
+ ])
147
+
148
+ def _process_image(self, images):
149
+ new_images = []
150
+ new_patch_position = []
151
+ num_image_mult = []
152
+ for image in images:
153
+ if self.add_global_img:
154
+ nocut_image = self.image_transform(self.old_resizer(image)).unsqueeze(0)
155
+
156
+ image, selected_anchor = self.resizer(image)
157
+ image_input = self.image_transform(image) # h,w,3 -> 3,h,w
158
+ # rearrange(x,'B C (n1 h) (n2 w) -> (B n1 n2) C h w', n1=self.down_sample[0], n2=self.down_sample[1])
159
+ image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
160
+
161
+ if self.add_global_img:
162
+ image_input = torch.cat([nocut_image, image_input], dim=0)
163
+
164
+ anchor = self.anchors[selected_anchor] # w,h
165
+ ic(anchor)
166
+ patch_position = torch.cat([
167
+ repeat(torch.arange(anchor[0]), 'num_h -> num_h num_w 1', num_w=anchor[1]),
168
+ repeat(torch.arange(anchor[1]), 'num_w -> num_h num_w 1', num_h=anchor[0])],dim=2)
169
+ patch_position = rearrange(patch_position, 'num_h num_w p-> (num_h num_w) p', p=2) # num_patch, (ph,pw)
170
+
171
+ if self.add_global_img:
172
+ patch_position = torch.cat([torch.ones(1,2).long()*self.anchor_max, patch_position], dim=0)
173
+
174
+ new_images.append(image_input)
175
+ new_patch_position.append(patch_position)
176
+ num_image_mult.append(patch_position.shape[0])
177
+
178
+ new_images = torch.cat(new_images,dim=0)
179
+ new_patch_position = torch.cat(new_patch_position, dim=0)
180
+ return new_images, new_patch_position, num_image_mult
181
+
182
+ def __call__(self, images=None, query=None):
183
+ assert images is not None
184
+
185
+ if not isinstance(images, list):
186
+ images = [images]
187
+ image_pils = []
188
+ for image in images:
189
+ if isinstance(image, str):
190
+ image = Image.open(image).convert('RGB')
191
+ else:
192
+ image = image.convert('RGB')
193
+ # ic(image.size)
194
+ image_pils.append(image)
195
+
196
+ image_data, patch_position, num_image_mult = self._process_image(image_pils)
197
+
198
+ assert self.media_token in query
199
+ text_list = query.split(self.media_token)
200
+ text = text_list[0]
201
+ image_token_ptr = 0
202
+ for next_text in text_list[1:]:
203
+ if self.add_textual_crop_indicator:
204
+ # generate image placeholders with interleaved textual crop indicators
205
+ # e.g. <global_img><|image|><crop_img_row0_col0><|image|><crop_img_row0_col1><|image|>...
206
+ for patch_pos in patch_position.tolist():
207
+ # global non-crop image
208
+ if patch_pos[0] == self.anchor_max and patch_pos[1] == self.anchor_max:
209
+ text += '<global_img><|image|>'
210
+ else:
211
+ row_col = 'row'+str(patch_pos[0])+'_col'+str(patch_pos[1])
212
+ text += '<crop_img_'+row_col+'><|image|>'
213
+ else:
214
+ # generate successive image placeholders for an image, 1 crop img == 1 <|image|>
215
+ text += '<|image|>'*num_image_mult[image_token_ptr]
216
+ text += next_text
217
+ image_token_ptr += 1
218
+
219
+ return image_data, patch_position, text
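A minimal usage sketch of the DocProcessor defined above; the image path and question are placeholders, and the grid_9 anchors match the default used by the serving code later in this upload.

# Hedged sketch: the image path and question are placeholders.
from mplug_docowl.processor import DocProcessor

processor = DocProcessor(image_size=448, anchors='grid_9',
                         add_global_img=True, add_textual_crop_indicator=True)
image_data, patch_position, text = processor(
    images='page.png',  # a PIL.Image also works; string paths are opened and converted to RGB
    query='<|image|>What is the title of this document?')
# image_data: (1 global crop + N crops, 3, 448, 448); patch_position: (1 + N, 2)
print(image_data.shape, patch_position.shape)
print(text)  # query with <global_img>/<crop_img_rowX_colY> indicators interleaved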
mplug_docowl/serve/__init__.py ADDED
File without changes
mplug_docowl/serve/cli.py ADDED
@@ -0,0 +1,120 @@
1
+ import argparse
2
+ import torch
3
+
4
+ from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
5
+ from mplug_owl2.conversation import conv_templates, SeparatorStyle
6
+ from mplug_owl2.model.builder import load_pretrained_model
7
+ from mplug_owl2.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
8
+
9
+ from PIL import Image
10
+
11
+ import requests
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from transformers import TextStreamer
15
+
16
+
17
+ def disable_torch_init():
18
+ """
19
+ Disable the redundant torch default initialization to accelerate model creation.
20
+ """
21
+ import torch
22
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
23
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
24
+
25
+
26
+ def load_image(image_file):
27
+ if image_file.startswith('http://') or image_file.startswith('https://'):
28
+ response = requests.get(image_file)
29
+ image = Image.open(BytesIO(response.content)).convert('RGB')
30
+ else:
31
+ image = Image.open(image_file).convert('RGB')
32
+ return image
33
+
34
+
35
+ def main(args):
36
+ # Model
37
+ disable_torch_init()
38
+
39
+ model_name = get_model_name_from_path(args.model_path)
40
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
41
+
42
+ conv_mode = "mplug_owl2"
43
+
44
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
45
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
46
+ else:
47
+ args.conv_mode = conv_mode
48
+
49
+ conv = conv_templates[args.conv_mode].copy()
50
+ roles = conv.roles
51
+
52
+ image = load_image(args.image_file)
53
+ # Similar operation in model_worker.py
54
+ image_tensor = process_images([image], image_processor, args)
55
+ if type(image_tensor) is list:
56
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
57
+ else:
58
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
59
+
60
+ while True:
61
+ try:
62
+ inp = input(f"{roles[0]}: ")
63
+ except EOFError:
64
+ inp = ""
65
+ if not inp:
66
+ print("exit...")
67
+ break
68
+
69
+ print(f"{roles[1]}: ", end="")
70
+
71
+ if image is not None:
72
+ # first message
73
+ inp = DEFAULT_IMAGE_TOKEN + inp
74
+ conv.append_message(conv.roles[0], inp)
75
+ image = None
76
+ else:
77
+ # later messages
78
+ conv.append_message(conv.roles[0], inp)
79
+ conv.append_message(conv.roles[1], None)
80
+ prompt = conv.get_prompt()
81
+
82
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
83
+ stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
84
+ keywords = [stop_str]
85
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
86
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
87
+
88
+ with torch.inference_mode():
89
+ output_ids = model.generate(
90
+ input_ids,
91
+ images=image_tensor,
92
+ do_sample=True,
93
+ temperature=args.temperature,
94
+ max_new_tokens=args.max_new_tokens,
95
+ streamer=streamer,
96
+ use_cache=True,
97
+ stopping_criteria=[stopping_criteria])
98
+
99
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
100
+ conv.messages[-1][-1] = outputs
101
+
102
+ if args.debug:
103
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
104
+
105
+
106
+ if __name__ == "__main__":
107
+ parser = argparse.ArgumentParser()
108
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
109
+ parser.add_argument("--model-base", type=str, default=None)
110
+ parser.add_argument("--image-file", type=str, required=True)
111
+ parser.add_argument("--device", type=str, default="cuda")
112
+ parser.add_argument("--conv-mode", type=str, default=None)
113
+ parser.add_argument("--temperature", type=float, default=0.2)
114
+ parser.add_argument("--max-new-tokens", type=int, default=512)
115
+ parser.add_argument("--load-8bit", action="store_true")
116
+ parser.add_argument("--load-4bit", action="store_true")
117
+ parser.add_argument("--debug", action="store_true")
118
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
119
+ args = parser.parse_args()
120
+ main(args)
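A minimal sketch of driving the CLI above programmatically; note that this script still imports from the mplug_owl2 package, so that package must be importable, and the checkpoint path and image file below are placeholders.

# Hedged sketch: checkpoint path and image file are placeholders.
from argparse import Namespace
from mplug_docowl.serve.cli import main

main(Namespace(
    model_path='path/to/docowl-checkpoint', model_base=None,
    image_file='page.png', device='cuda', conv_mode=None,
    temperature=0.2, max_new_tokens=512,
    load_8bit=False, load_4bit=False, debug=False,
    image_aspect_ratio='pad'))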
mplug_docowl/serve/controller.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from mplug_owl2.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from mplug_owl2.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: float
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stale_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(
64
+ target=heart_beat_controller, args=(self,))
65
+ self.heart_beat_thread.start()
66
+
67
+ logger.info("Init controller")
68
+
69
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
70
+ worker_status: dict):
71
+ if worker_name not in self.worker_info:
72
+ logger.info(f"Register a new worker: {worker_name}")
73
+ else:
74
+ logger.info(f"Register an existing worker: {worker_name}")
75
+
76
+ if not worker_status:
77
+ worker_status = self.get_worker_status(worker_name)
78
+ if not worker_status:
79
+ return False
80
+
81
+ self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
+ check_heart_beat, time.time())
84
+
85
+ logger.info(f"Register done: {worker_name}, {worker_status}")
86
+ return True
87
+
88
+ def get_worker_status(self, worker_name: str):
89
+ try:
90
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
+ except requests.exceptions.RequestException as e:
92
+ logger.error(f"Get status fails: {worker_name}, {e}")
93
+ return None
94
+
95
+ if r.status_code != 200:
96
+ logger.error(f"Get status fails: {worker_name}, {r}")
97
+ return None
98
+
99
+ return r.json()
100
+
101
+ def remove_worker(self, worker_name: str):
102
+ del self.worker_info[worker_name]
103
+
104
+ def refresh_all_workers(self):
105
+ old_info = dict(self.worker_info)
106
+ self.worker_info = {}
107
+
108
+ for w_name, w_info in old_info.items():
109
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
+ logger.info(f"Remove stale worker: {w_name}")
111
+
112
+ def list_models(self):
113
+ model_names = set()
114
+
115
+ for w_name, w_info in self.worker_info.items():
116
+ model_names.update(w_info.model_names)
117
+
118
+ return list(model_names)
119
+
120
+ def get_worker_address(self, model_name: str):
121
+ if self.dispatch_method == DispatchMethod.LOTTERY:
122
+ worker_names = []
123
+ worker_speeds = []
124
+ for w_name, w_info in self.worker_info.items():
125
+ if model_name in w_info.model_names:
126
+ worker_names.append(w_name)
127
+ worker_speeds.append(w_info.speed)
128
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
+ norm = np.sum(worker_speeds)
130
+ if norm < 1e-4:
131
+ return ""
132
+ worker_speeds = worker_speeds / norm
133
+ if True: # Directly return address
134
+ pt = np.random.choice(np.arange(len(worker_names)),
135
+ p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)),
142
+ p=worker_speeds)
143
+ worker_name = worker_names[pt]
144
+
145
+ if self.get_worker_status(worker_name):
146
+ break
147
+ else:
148
+ self.remove_worker(worker_name)
149
+ worker_speeds[pt] = 0
150
+ norm = np.sum(worker_speeds)
151
+ if norm < 1e-4:
152
+ return ""
153
+ worker_speeds = worker_speeds / norm
154
+ continue
155
+ return worker_name
156
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+ else:
171
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
+
173
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
174
+ if worker_name not in self.worker_info:
175
+ logger.info(f"Receive unknown heart beat. {worker_name}")
176
+ return False
177
+
178
+ self.worker_info[worker_name].queue_length = queue_length
179
+ self.worker_info[worker_name].last_heart_beat = time.time()
180
+ logger.info(f"Receive heart beat. {worker_name}")
181
+ return True
182
+
183
+ def remove_stale_workers_by_expiration(self):
184
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
+ to_delete = []
186
+ for worker_name, w_info in self.worker_info.items():
187
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
+ to_delete.append(worker_name)
189
+
190
+ for worker_name in to_delete:
191
+ self.remove_worker(worker_name)
192
+
193
+ def worker_api_generate_stream(self, params):
194
+ worker_addr = self.get_worker_address(params["model"])
195
+ if not worker_addr:
196
+ logger.info(f"no worker: {params['model']}")
197
+ ret = {
198
+ "text": server_error_msg,
199
+ "error_code": 2,
200
+ }
201
+ yield json.dumps(ret).encode() + b"\0"
202
+
203
+ try:
204
+ response = requests.post(worker_addr + "/worker_generate_stream",
205
+ json=params, stream=True, timeout=5)
206
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
+ if chunk:
208
+ yield chunk + b"\0"
209
+ except requests.exceptions.RequestException as e:
210
+ logger.info(f"worker timeout: {worker_addr}")
211
+ ret = {
212
+ "text": server_error_msg,
213
+ "error_code": 3,
214
+ }
215
+ yield json.dumps(ret).encode() + b"\0"
216
+
217
+
218
+ # Let the controller act as a worker to achieve hierarchical
219
+ # management. This can be used to connect isolated sub networks.
220
+ def worker_api_get_status(self):
221
+ model_names = set()
222
+ speed = 0
223
+ queue_length = 0
224
+
225
+ for w_name in self.worker_info:
226
+ worker_status = self.get_worker_status(w_name)
227
+ if worker_status is not None:
228
+ model_names.update(worker_status["model_names"])
229
+ speed += worker_status["speed"]
230
+ queue_length += worker_status["queue_length"]
231
+
232
+ return {
233
+ "model_names": list(model_names),
234
+ "speed": speed,
235
+ "queue_length": queue_length,
236
+ }
237
+
238
+
239
+ app = FastAPI()
240
+
241
+
242
+ @app.post("/register_worker")
243
+ async def register_worker(request: Request):
244
+ data = await request.json()
245
+ controller.register_worker(
246
+ data["worker_name"], data["check_heart_beat"],
247
+ data.get("worker_status", None))
248
+
249
+
250
+ @app.post("/refresh_all_workers")
251
+ async def refresh_all_workers():
252
+ models = controller.refresh_all_workers()
253
+
254
+
255
+ @app.post("/list_models")
256
+ async def list_models():
257
+ models = controller.list_models()
258
+ return {"models": models}
259
+
260
+
261
+ @app.post("/get_worker_address")
262
+ async def get_worker_address(request: Request):
263
+ data = await request.json()
264
+ addr = controller.get_worker_address(data["model"])
265
+ return {"address": addr}
266
+
267
+
268
+ @app.post("/receive_heart_beat")
269
+ async def receive_heart_beat(request: Request):
270
+ data = await request.json()
271
+ exist = controller.receive_heart_beat(
272
+ data["worker_name"], data["queue_length"])
273
+ return {"exist": exist}
274
+
275
+
276
+ @app.post("/worker_generate_stream")
277
+ async def worker_api_generate_stream(request: Request):
278
+ params = await request.json()
279
+ generator = controller.worker_api_generate_stream(params)
280
+ return StreamingResponse(generator)
281
+
282
+
283
+ @app.post("/worker_get_status")
284
+ async def worker_api_get_status(request: Request):
285
+ return controller.worker_api_get_status()
286
+
287
+
288
+ if __name__ == "__main__":
289
+ parser = argparse.ArgumentParser()
290
+ parser.add_argument("--host", type=str, default="localhost")
291
+ parser.add_argument("--port", type=int, default=21001)
292
+ parser.add_argument("--dispatch-method", type=str, choices=[
293
+ "lottery", "shortest_queue"], default="shortest_queue")
294
+ args = parser.parse_args()
295
+ logger.info(f"args: {args}")
296
+
297
+ controller = Controller(args.dispatch_method)
298
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
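A minimal sketch of querying a running controller through the HTTP endpoints defined above (default port 21001, as in the argparse defaults); the model name is a placeholder.

# Hedged sketch: assumes a controller is already running on localhost:21001.
import requests

base = "http://localhost:21001"
requests.post(base + "/refresh_all_workers")
print(requests.post(base + "/list_models").json())        # {"models": [...]}
print(requests.post(base + "/get_worker_address",
                    json={"model": "docowl"}).json())      # "docowl" is a placeholder model name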
mplug_docowl/serve/examples/Rebecca_(1939_poster)_Small.jpeg ADDED
mplug_docowl/serve/examples/extreme_ironing.jpg ADDED
mplug_docowl/serve/gradio_web_server.py ADDED
@@ -0,0 +1,460 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ from mplug_owl2.conversation import (default_conversation, conv_templates,
11
+ SeparatorStyle)
12
+ from mplug_owl2.constants import LOGDIR
13
+ from mplug_owl2.utils import (build_logger, server_error_msg,
14
+ violates_moderation, moderation_msg)
15
+ import hashlib
16
+
17
+
18
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
+
20
+ headers = {"User-Agent": "mPLUG-Owl2 Client"}
21
+
22
+ no_change_btn = gr.Button.update()
23
+ enable_btn = gr.Button.update(interactive=True)
24
+ disable_btn = gr.Button.update(interactive=False)
25
+
26
+ priority = {
27
+ "vicuna-13b": "aaaaaaa",
28
+ "koala-13b": "aaaaaab",
29
+ }
30
+
31
+
32
+ def get_conv_log_filename():
33
+ t = datetime.datetime.now()
34
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
+ return name
36
+
37
+
38
+ def get_model_list():
39
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
40
+ assert ret.status_code == 200
41
+ ret = requests.post(args.controller_url + "/list_models")
42
+ models = ret.json()["models"]
43
+ models.sort(key=lambda x: priority.get(x, x))
44
+ logger.info(f"Models: {models}")
45
+ return models
46
+
47
+
48
+ get_window_url_params = """
49
+ function() {
50
+ const params = new URLSearchParams(window.location.search);
51
+ url_params = Object.fromEntries(params);
52
+ console.log(url_params);
53
+ return url_params;
54
+ }
55
+ """
56
+
57
+
58
+ def load_demo(url_params, request: gr.Request):
59
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
+
61
+ dropdown_update = gr.Dropdown.update(visible=True)
62
+ if "model" in url_params:
63
+ model = url_params["model"]
64
+ if model in models:
65
+ dropdown_update = gr.Dropdown.update(
66
+ value=model, visible=True)
67
+
68
+ state = default_conversation.copy()
69
+ return state, dropdown_update
70
+
71
+
72
+ def load_demo_refresh_model_list(request: gr.Request):
73
+ logger.info(f"load_demo. ip: {request.client.host}")
74
+ models = get_model_list()
75
+ state = default_conversation.copy()
76
+ dropdown_update = gr.Dropdown.update(
77
+ choices=models,
78
+ value=models[0] if len(models) > 0 else ""
79
+ )
80
+ return state, dropdown_update
81
+
82
+
83
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
+ with open(get_conv_log_filename(), "a") as fout:
85
+ data = {
86
+ "tstamp": round(time.time(), 4),
87
+ "type": vote_type,
88
+ "model": model_selector,
89
+ "state": state.dict(),
90
+ "ip": request.client.host,
91
+ }
92
+ fout.write(json.dumps(data) + "\n")
93
+
94
+
95
+ def upvote_last_response(state, model_selector, request: gr.Request):
96
+ logger.info(f"upvote. ip: {request.client.host}")
97
+ vote_last_response(state, "upvote", model_selector, request)
98
+ return ("",) + (disable_btn,) * 3
99
+
100
+
101
+ def downvote_last_response(state, model_selector, request: gr.Request):
102
+ logger.info(f"downvote. ip: {request.client.host}")
103
+ vote_last_response(state, "downvote", model_selector, request)
104
+ return ("",) + (disable_btn,) * 3
105
+
106
+
107
+ def flag_last_response(state, model_selector, request: gr.Request):
108
+ logger.info(f"flag. ip: {request.client.host}")
109
+ vote_last_response(state, "flag", model_selector, request)
110
+ return ("",) + (disable_btn,) * 3
111
+
112
+
113
+ def regenerate(state, image_process_mode, request: gr.Request):
114
+ logger.info(f"regenerate. ip: {request.client.host}")
115
+ state.messages[-1][-1] = None
116
+ prev_human_msg = state.messages[-2]
117
+ if type(prev_human_msg[1]) in (tuple, list):
118
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119
+ state.skip_next = False
120
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
121
+
122
+
123
+ def clear_history(request: gr.Request):
124
+ logger.info(f"clear_history. ip: {request.client.host}")
125
+ state = default_conversation.copy()
126
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
127
+
128
+
129
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
130
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
131
+ if len(text) <= 0 and image is None:
132
+ state.skip_next = True
133
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
134
+ if args.moderate:
135
+ flagged = violates_moderation(text)
136
+ if flagged:
137
+ state.skip_next = True
138
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
139
+ no_change_btn,) * 5
140
+
141
+ text = text[:1536] # Hard cut-off
142
+ if image is not None:
143
+ text = text[:1200] # Hard cut-off for images
144
+ if '<|image|>' not in text:
145
+ # text = text + '<|image|>'
146
+ text = '<|image|>' + text
147
+ text = (text, image, image_process_mode)
148
+ if len(state.get_images(return_pil=True)) > 0:
149
+ state = default_conversation.copy()
150
+ state.append_message(state.roles[0], text)
151
+ state.append_message(state.roles[1], None)
152
+ state.skip_next = False
153
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154
+
155
+
156
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
157
+ logger.info(f"http_bot. ip: {request.client.host}")
158
+ start_tstamp = time.time()
159
+ model_name = model_selector
160
+
161
+ if state.skip_next:
162
+ # This generate call is skipped due to invalid inputs
163
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164
+ return
165
+
166
+ if len(state.messages) == state.offset + 2:
167
+ # First round of conversation
168
+ template_name = "mplug_owl2"
169
+ new_state = conv_templates[template_name].copy()
170
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
171
+ new_state.append_message(new_state.roles[1], None)
172
+ state = new_state
173
+
174
+ # Query worker address
175
+ controller_url = args.controller_url
176
+ ret = requests.post(controller_url + "/get_worker_address",
177
+ json={"model": model_name})
178
+ worker_addr = ret.json()["address"]
179
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
180
+
181
+ # No available worker
182
+ if worker_addr == "":
183
+ state.messages[-1][-1] = server_error_msg
184
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
185
+ return
186
+
187
+ # Construct prompt
188
+ prompt = state.get_prompt()
189
+
190
+ all_images = state.get_images(return_pil=True)
191
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
192
+ for image, hash in zip(all_images, all_image_hash):
193
+ t = datetime.datetime.now()
194
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
195
+ if not os.path.isfile(filename):
196
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
197
+ image.save(filename)
198
+
199
+ # Make requests
200
+ pload = {
201
+ "model": model_name,
202
+ "prompt": prompt,
203
+ "temperature": float(temperature),
204
+ "top_p": float(top_p),
205
+ "max_new_tokens": min(int(max_new_tokens), 1536),
206
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
207
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
208
+ }
209
+ logger.info(f"==== request ====\n{pload}")
210
+
211
+ pload['images'] = state.get_images()
212
+
213
+ state.messages[-1][-1] = "▌"
214
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
215
+
216
+ try:
217
+ # Stream output
218
+ response = requests.post(worker_addr + "/worker_generate_stream",
219
+ headers=headers, json=pload, stream=True, timeout=10)
220
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
221
+ if chunk:
222
+ data = json.loads(chunk.decode())
223
+ if data["error_code"] == 0:
224
+ output = data["text"][len(prompt):].strip()
225
+ state.messages[-1][-1] = output + "▌"
226
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
227
+ else:
228
+ output = data["text"] + f" (error_code: {data['error_code']})"
229
+ state.messages[-1][-1] = output
230
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
231
+ return
232
+ time.sleep(0.03)
233
+ except requests.exceptions.RequestException as e:
234
+ state.messages[-1][-1] = server_error_msg
235
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
236
+ return
237
+
238
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
239
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
240
+
241
+ finish_tstamp = time.time()
242
+ logger.info(f"{output}")
243
+
244
+ with open(get_conv_log_filename(), "a") as fout:
245
+ data = {
246
+ "tstamp": round(finish_tstamp, 4),
247
+ "type": "chat",
248
+ "model": model_name,
249
+ "start": round(start_tstamp, 4),
250
+ "finish": round(start_tstamp, 4),
251
+ "state": state.dict(),
252
+ "images": all_image_hash,
253
+ "ip": request.client.host,
254
+ }
255
+ fout.write(json.dumps(data) + "\n")
256
+
257
+
258
+ title_markdown = ("""
259
+ <h1 align="center"><a href="https://github.com/X-PLUG/mPLUG-Owl"><img src="https://z1.ax1x.com/2023/11/03/piM1rGQ.md.png", alt="mPLUG-Owl" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>
260
+
261
+ <h2 align="center"> mPLUG-Owl2: Revolutionizing Multi-modal Large Language Model with Modality Collaboration</h2>
262
+
263
+ <h5 align="center"> If you like our project, please give us a star ✨ on Github for latest update. </h2>
264
+
265
+ <div align="center">
266
+ <div style="display:flex; gap: 0.25rem;" align="center">
267
+ <a href='https://github.com/X-PLUG/mPLUG-Owl'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
268
+ <a href="https://arxiv.org/abs/2304.14178"><img src="https://img.shields.io/badge/Arxiv-2304.14178-red"></a>
269
+ <a href='https://github.com/X-PLUG/mPLUG-Owl/stargazers'><img src='https://img.shields.io/github/stars/X-PLUG/mPLUG-Owl.svg?style=social'></a>
270
+ </div>
271
+ </div>
272
+
273
+ """)
274
+
275
+
276
+ tos_markdown = ("""
277
+ ### Terms of use
278
+ By using this service, users are required to agree to the following terms:
279
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
280
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
281
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
282
+ """)
283
+
284
+
285
+ learn_more_markdown = ("""
286
+ ### License
287
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
288
+ """)
289
+
290
+ block_css = """
291
+
292
+ #buttons button {
293
+ min-width: min(120px,100%);
294
+ }
295
+
296
+ """
297
+
298
+ def build_demo(embed_mode):
299
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
300
+ with gr.Blocks(title="mPLUG-Owl2", theme=gr.themes.Default(), css=block_css) as demo:
301
+ state = gr.State()
302
+
303
+ if not embed_mode:
304
+ gr.Markdown(title_markdown)
305
+
306
+ with gr.Row():
307
+ with gr.Column(scale=3):
308
+ with gr.Row(elem_id="model_selector_row"):
309
+ model_selector = gr.Dropdown(
310
+ choices=models,
311
+ value=models[0] if len(models) > 0 else "",
312
+ interactive=True,
313
+ show_label=False,
314
+ container=False)
315
+
316
+ imagebox = gr.Image(type="pil")
317
+ image_process_mode = gr.Radio(
318
+ ["Crop", "Resize", "Pad", "Default"],
319
+ value="Default",
320
+ label="Preprocess for non-square image", visible=False)
321
+
322
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
323
+ gr.Examples(examples=[
324
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
325
+ [f"{cur_dir}/examples/Rebecca_(1939_poster)_Small.jpeg", "What is the name of the movie in the poster?"],
326
+ ], inputs=[imagebox, textbox])
327
+
328
+ with gr.Accordion("Parameters", open=True) as parameter_row:
329
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
330
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
331
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
332
+
333
+ with gr.Column(scale=8):
334
+ chatbot = gr.Chatbot(elem_id="Chatbot", label="mPLUG-Owl2 Chatbot", height=600)
335
+ with gr.Row():
336
+ with gr.Column(scale=8):
337
+ textbox.render()
338
+ with gr.Column(scale=1, min_width=50):
339
+ submit_btn = gr.Button(value="Send", variant="primary")
340
+ with gr.Row(elem_id="buttons") as button_row:
341
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
342
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
343
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
344
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
345
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
346
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
347
+
348
+ if not embed_mode:
349
+ gr.Markdown(tos_markdown)
350
+ gr.Markdown(learn_more_markdown)
351
+ url_params = gr.JSON(visible=False)
352
+
353
+ # Register listeners
354
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
355
+ upvote_btn.click(
356
+ upvote_last_response,
357
+ [state, model_selector],
358
+ [textbox, upvote_btn, downvote_btn, flag_btn],
359
+ queue=False
360
+ )
361
+ downvote_btn.click(
362
+ downvote_last_response,
363
+ [state, model_selector],
364
+ [textbox, upvote_btn, downvote_btn, flag_btn],
365
+ queue=False
366
+ )
367
+ flag_btn.click(
368
+ flag_last_response,
369
+ [state, model_selector],
370
+ [textbox, upvote_btn, downvote_btn, flag_btn],
371
+ queue=False
372
+ )
373
+
374
+ regenerate_btn.click(
375
+ regenerate,
376
+ [state, image_process_mode],
377
+ [state, chatbot, textbox, imagebox] + btn_list,
378
+ queue=False
379
+ ).then(
380
+ http_bot,
381
+ [state, model_selector, temperature, top_p, max_output_tokens],
382
+ [state, chatbot] + btn_list
383
+ )
384
+
385
+ clear_btn.click(
386
+ clear_history,
387
+ None,
388
+ [state, chatbot, textbox, imagebox] + btn_list,
389
+ queue=False
390
+ )
391
+
392
+ textbox.submit(
393
+ add_text,
394
+ [state, textbox, imagebox, image_process_mode],
395
+ [state, chatbot, textbox, imagebox] + btn_list,
396
+ queue=False
397
+ ).then(
398
+ http_bot,
399
+ [state, model_selector, temperature, top_p, max_output_tokens],
400
+ [state, chatbot] + btn_list
401
+ )
402
+
403
+ submit_btn.click(
404
+ add_text,
405
+ [state, textbox, imagebox, image_process_mode],
406
+ [state, chatbot, textbox, imagebox] + btn_list,
407
+ queue=False
408
+ ).then(
409
+ http_bot,
410
+ [state, model_selector, temperature, top_p, max_output_tokens],
411
+ [state, chatbot] + btn_list
412
+ )
413
+
414
+ if args.model_list_mode == "once":
415
+ demo.load(
416
+ load_demo,
417
+ [url_params],
418
+ [state, model_selector],
419
+ _js=get_window_url_params,
420
+ queue=False
421
+ )
422
+ elif args.model_list_mode == "reload":
423
+ demo.load(
424
+ load_demo_refresh_model_list,
425
+ None,
426
+ [state, model_selector],
427
+ queue=False
428
+ )
429
+ else:
430
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
431
+
432
+ return demo
433
+
434
+
435
+ if __name__ == "__main__":
436
+ parser = argparse.ArgumentParser()
437
+ parser.add_argument("--host", type=str, default="0.0.0.0")
438
+ parser.add_argument("--port", type=int)
439
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
440
+ parser.add_argument("--concurrency-count", type=int, default=10)
441
+ parser.add_argument("--model-list-mode", type=str, default="once",
442
+ choices=["once", "reload"])
443
+ parser.add_argument("--share", action="store_true")
444
+ parser.add_argument("--moderate", action="store_true")
445
+ parser.add_argument("--embed", action="store_true")
446
+ args = parser.parse_args()
447
+ logger.info(f"args: {args}")
448
+
449
+ models = get_model_list()
450
+
451
+ logger.info(args)
452
+ demo = build_demo(args.embed)
453
+ demo.queue(
454
+ concurrency_count=args.concurrency_count,
455
+ api_open=False
456
+ ).launch(
457
+ server_name=args.host,
458
+ server_port=args.port,
459
+ share=args.share
460
+ )
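The web server above expects a controller and at least one model worker to be reachable; a hedged sketch of the usual launch order, assuming the package is installed and the modules are run as scripts (the checkpoint path is a placeholder, and this file still imports from the mplug_owl2 package).

# Hedged sketch of the launch order; adjust hosts and ports as needed.
# 1. python -m mplug_docowl.serve.controller --host 0.0.0.0 --port 21001
# 2. python -m mplug_docowl.serve.model_worker --controller-address http://localhost:21001 \
#        --worker-address http://localhost:21002 --port 21002 --model-path path/to/docowl-checkpoint
# 3. python -m mplug_docowl.serve.gradio_web_server --controller-url http://localhost:21001 --port 7860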
mplug_docowl/serve/model_worker.py ADDED
@@ -0,0 +1,342 @@
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import StreamingResponse
13
+ import requests
14
+ import torch
15
+ import uvicorn
16
+ from functools import partial
17
+
18
+ from mplug_docowl.utils import (build_logger, server_error_msg,
19
+ pretty_print_semaphore)
20
+
21
+ from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,WORKER_HEART_BEAT_INTERVAL
22
+ from mplug_docowl.conversation import conv_templates, SeparatorStyle
23
+ from mplug_docowl.model.builder import load_pretrained_model
24
+ from mplug_docowl.mm_utils import load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
25
+ from mplug_docowl.processor import DocProcessor
26
+
27
+
28
+ from transformers import TextIteratorStreamer, TextStreamer
+ from icecream import ic
29
+ from threading import Thread
30
+
31
+
32
+ GB = 1 << 30
33
+
34
+ worker_id = str(uuid.uuid4())[:6]
35
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
36
+ global_counter = 0
37
+
38
+ model_semaphore = None
39
+
40
+
41
+ def heart_beat_worker(controller):
42
+
43
+ while True:
44
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
45
+ controller.send_heart_beat()
46
+
47
+
48
+ class DocOwlInfer():
49
+ def __init__(self, ckpt_path, anchors='grid_9', add_global_img=True, load_8bit=False, load_4bit=False):
50
+ model_name = get_model_name_from_path(ckpt_path)
51
+ ic(model_name)
52
+ self.tokenizer, self.model, _, _ = load_pretrained_model(ckpt_path, None, model_name, load_8bit=load_8bit, load_4bit=load_4bit, device="cuda")
53
+ self.doc_image_processor = DocProcessor(image_size=448, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
54
+ self.streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
55
+
56
+ def inference(self, image, query):
57
+ image_tensor, patch_positions, text = self.doc_image_processor(images=image, query='<|image|>'+query)
58
+ image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
59
+ patch_positions = patch_positions.to(self.model.device)
60
+
61
+ # ic(image_tensor.shape, patch_positions.shape, text)
62
+
63
+ conv = conv_templates["mplug_owl2"].copy()
64
+ roles = conv.roles # ("USER", "ASSISTANT")
65
+
66
+ conv.append_message(conv.roles[0], text)
67
+ conv.append_message(conv.roles[1], None)
68
+ prompt = conv.get_prompt()
69
+
70
+ # ic(prompt)
71
+
72
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
73
+
74
+ # ic(input_ids)
75
+
76
+ stop_str = conv.sep2
77
+ keywords = [stop_str]
78
+ stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
79
+
80
+ with torch.inference_mode():
81
+ output_ids = self.model.generate(
82
+ input_ids,
83
+ images=image_tensor,
84
+ patch_positions=patch_positions,
85
+ do_sample=False,
86
+ temperature=1.0,
87
+ max_new_tokens=512,
88
+ streamer=self.streamer,
89
+ use_cache=True,
90
+ stopping_criteria=[stopping_criteria])
91
+
92
+ outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
93
+
94
+ return outputs.replace('</s>', '')
95
+
96
+ # TODO: adapt for docowl infer
97
+ class ModelWorker:
98
+ def __init__(self, controller_addr, worker_addr,
99
+ worker_id, no_register,
100
+ model_path, model_base, model_name,
101
+ resolution, anchors, add_global_img,
102
+ load_8bit, load_4bit, device):
103
+ self.controller_addr = controller_addr
104
+ self.worker_addr = worker_addr
105
+ self.worker_id = worker_id
106
+ if model_path.endswith("/"):
107
+ model_path = model_path[:-1]
108
+
109
+ self.model_name = get_model_name_from_path(model_path)
110
+
111
+ self.device = device
112
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
113
+
114
+ self.tokenizer, self.model, _, self.context_len = load_pretrained_model(
115
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
116
+
117
+ self.resolution=resolution
118
+ self.token_num_each_img = (self.resolution/14)*(self.resolution/14)/self.model.get_model().vison2text.conv_patch
119
+ self.doc_image_processor = DocProcessor(image_size=resolution, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
120
+
121
+
122
+ self.is_multimodal = True
123
+
124
+ if not no_register:
125
+ self.register_to_controller()
126
+ self.heart_beat_thread = threading.Thread(
127
+ target=heart_beat_worker, args=(self,))
128
+ self.heart_beat_thread.start()
129
+
130
+ def register_to_controller(self):
131
+ logger.info("Register to controller")
132
+
133
+ url = self.controller_addr + "/register_worker"
134
+ data = {
135
+ "worker_name": self.worker_addr,
136
+ "check_heart_beat": True,
137
+ "worker_status": self.get_status()
138
+ }
139
+ r = requests.post(url, json=data)
140
+ assert r.status_code == 200
141
+
142
+ def send_heart_beat(self):
143
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
144
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
145
+ f"global_counter: {global_counter}")
146
+
147
+ url = self.controller_addr + "/receive_heart_beat"
148
+
149
+ while True:
150
+ try:
151
+ ret = requests.post(url, json={
152
+ "worker_name": self.worker_addr,
153
+ "queue_length": self.get_queue_length()}, timeout=5)
154
+ exist = ret.json()["exist"]
155
+ break
156
+ except requests.exceptions.RequestException as e:
157
+ logger.error(f"heart beat error: {e}")
158
+ time.sleep(5)
159
+
160
+ if not exist:
161
+ self.register_to_controller()
162
+
163
+ def get_queue_length(self):
164
+ if model_semaphore is None:
165
+ return 0
166
+ else:
167
+ return args.limit_model_concurrency - model_semaphore._value + (len(
168
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
169
+
170
+ def get_status(self):
171
+ return {
172
+ "model_names": [self.model_name],
173
+ "speed": 1,
174
+ "queue_length": self.get_queue_length(),
175
+ }
176
+
177
+ @torch.inference_mode()
178
+ def generate_stream(self, params):
179
+ tokenizer, model = self.tokenizer, self.model # this worker uses self.doc_image_processor instead of an image_processor
180
+
181
+ prompt = params["prompt"]
182
+ ori_prompt = prompt
183
+ images = params.get("images", None)
184
+ num_image_tokens = 0
185
+ if images is not None and len(images) > 0 and self.is_multimodal:
186
+ if len(images) > 0:
187
+
188
+ """if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
189
+ raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
190
+
191
+ images = [load_image_from_base64(image) for image in images]
192
+ images = process_images(images, image_processor, model.config)
193
+
194
+ if type(images) is list:
195
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
196
+ else:
197
+ images = images.to(self.model.device, dtype=torch.float16)"""
198
+
199
+ # docowl only support 1 image, so only keep the last image
200
+ image = images[-1]
201
+ assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
202
+
203
+ image_tensor, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
204
+ image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
205
+ patch_positions = patch_positions.to(self.model.device)
206
+
207
+ replace_token = DEFAULT_IMAGE_TOKEN
208
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
209
+ num_image_tokens = prompt.count(replace_token) * (self.token_num_each_img+1)
210
+ else:
211
+ image_tensor = None
212
+ patch_positions = None
213
+ image_args = {"images": images, "patch_positions":patch_positions}
214
+ else:
215
+ images = None
216
+ image_args = {}
217
+
218
+ temperature = float(params.get("temperature", 1.0))
219
+ top_p = float(params.get("top_p", 1.0))
220
+ max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
221
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
222
+ stop_str = params.get("stop", None)
223
+ do_sample = True if temperature > 0.001 else False
224
+
225
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
226
+ keywords = [stop_str]
227
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
228
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
229
+
230
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
231
+
232
+ if max_new_tokens < 1:
233
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
234
+ return
235
+
236
+ thread = Thread(target=model.generate, kwargs=dict(
237
+ inputs=input_ids,
238
+ do_sample=do_sample,
239
+ temperature=temperature,
240
+ top_p=top_p,
241
+ max_new_tokens=max_new_tokens,
242
+ streamer=streamer,
243
+ stopping_criteria=[stopping_criteria],
244
+ use_cache=True,
245
+ **image_args
246
+ ))
247
+ thread.start()
248
+
249
+ generated_text = ori_prompt
250
+ for new_text in streamer:
251
+ generated_text += new_text
252
+ if generated_text.endswith(stop_str):
253
+ generated_text = generated_text[:-len(stop_str)]
254
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
255
+
256
+ def generate_stream_gate(self, params):
257
+ try:
258
+ for x in self.generate_stream(params):
259
+ yield x
260
+ except ValueError as e:
261
+ print("Caught ValueError:", e)
262
+ ret = {
263
+ "text": server_error_msg,
264
+ "error_code": 1,
265
+ }
266
+ yield json.dumps(ret).encode() + b"\0"
267
+ except torch.cuda.CudaError as e:
268
+ print("Caught torch.cuda.CudaError:", e)
269
+ ret = {
270
+ "text": server_error_msg,
271
+ "error_code": 1,
272
+ }
273
+ yield json.dumps(ret).encode() + b"\0"
274
+ except Exception as e:
275
+ print("Caught Unknown Error", e)
276
+ ret = {
277
+ "text": server_error_msg,
278
+ "error_code": 1,
279
+ }
280
+ yield json.dumps(ret).encode() + b"\0"
281
+
282
+ app = FastAPI()
283
+
284
+ def release_model_semaphore(fn=None):
285
+ model_semaphore.release()
286
+ if fn is not None:
287
+ fn()
288
+
289
+
290
+ @app.post("/worker_generate_stream")
291
+ async def generate_stream(request: Request):
292
+ global model_semaphore, global_counter
293
+ global_counter += 1
294
+ params = await request.json()
295
+
296
+ if model_semaphore is None:
297
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
298
+ await model_semaphore.acquire()
299
+ worker.send_heart_beat()
300
+ generator = worker.generate_stream_gate(params)
301
+ background_tasks = BackgroundTasks()
302
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
303
+ return StreamingResponse(generator, background=background_tasks)
304
+
305
+
306
+ @app.post("/worker_get_status")
307
+ async def get_status(request: Request):
308
+ return worker.get_status()
309
+
310
+
311
+ if __name__ == "__main__":
312
+ parser = argparse.ArgumentParser()
313
+ parser.add_argument("--host", type=str, default="localhost")
314
+ parser.add_argument("--port", type=int, default=21002)
315
+ parser.add_argument("--worker-address", type=str,
316
+ default="http://localhost:21002")
317
+ parser.add_argument("--controller-address", type=str,
318
+ default="http://localhost:21001")
319
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
320
+ parser.add_argument("--model-base", type=str, default=None)
321
+ parser.add_argument("--model-name", type=str)
322
+ parser.add_argument("--device", type=str, default="cuda")
323
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
324
+ parser.add_argument("--stream-interval", type=int, default=1)
325
+ parser.add_argument("--no-register", action="store_true")
326
+ parser.add_argument("--load-8bit", action="store_true")
327
+ parser.add_argument("--load-4bit", action="store_true")
328
+ args = parser.parse_args()
329
+ logger.info(f"args: {args}")
330
+
331
+
332
+ worker = ModelWorker(args.controller_address,
333
+ args.worker_address,
334
+ worker_id,
335
+ args.no_register,
336
+ args.model_path,
337
+ args.model_base,
338
+ args.model_name,
+ args.resolution,
+ args.anchors,
+ True, # add_global_img, matching the DocOwlInfer default above
339
+ args.load_8bit,
340
+ args.load_4bit,
341
+ args.device)
342
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
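Besides running as a worker service, the DocOwlInfer helper defined at the top of this file can be used directly; a minimal sketch, with the checkpoint path and image file as placeholders.

# Hedged sketch: checkpoint path and image file are placeholders.
from mplug_docowl.serve.model_worker import DocOwlInfer

infer = DocOwlInfer(ckpt_path='path/to/docowl-checkpoint',
                    anchors='grid_9', add_global_img=True)
answer = infer.inference('page.png', 'What is the title of this document?')
print(answer)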
mplug_docowl/serve/model_worker_bak.py ADDED
@@ -0,0 +1,278 @@
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import StreamingResponse
13
+ import requests
14
+ import torch
15
+ import uvicorn
16
+ from functools import partial
17
+
18
+ from mplug_owl2.constants import WORKER_HEART_BEAT_INTERVAL
19
+ from mplug_owl2.utils import (build_logger, server_error_msg,
20
+ pretty_print_semaphore)
21
+ from mplug_owl2.model.builder import load_pretrained_model
22
+ from mplug_owl2.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
23
+ from mplug_owl2.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
24
+ from transformers import TextIteratorStreamer
25
+ from threading import Thread
26
+
27
+
28
+ GB = 1 << 30
29
+
30
+ worker_id = str(uuid.uuid4())[:6]
31
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32
+ global_counter = 0
33
+
34
+ model_semaphore = None
35
+
36
+
37
+ def heart_beat_worker(controller):
38
+
39
+ while True:
40
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
+ controller.send_heart_beat()
42
+
43
+
44
+ class ModelWorker:
45
+ def __init__(self, controller_addr, worker_addr,
46
+ worker_id, no_register,
47
+ model_path, model_base, model_name,
48
+ load_8bit, load_4bit, device):
49
+ self.controller_addr = controller_addr
50
+ self.worker_addr = worker_addr
51
+ self.worker_id = worker_id
52
+ if model_path.endswith("/"):
53
+ model_path = model_path[:-1]
54
+ if model_name is None:
55
+ model_paths = model_path.split("/")
56
+ if model_paths[-1].startswith('checkpoint-'):
57
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
58
+ else:
59
+ self.model_name = model_paths[-1]
60
+ else:
61
+ self.model_name = model_name
62
+
63
+ self.device = device
64
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
67
+ self.is_multimodal = True
68
+
69
+ if not no_register:
70
+ self.register_to_controller()
71
+ self.heart_beat_thread = threading.Thread(
72
+ target=heart_beat_worker, args=(self,))
73
+ self.heart_beat_thread.start()
74
+
75
+ def register_to_controller(self):
76
+ logger.info("Register to controller")
77
+
78
+ url = self.controller_addr + "/register_worker"
79
+ data = {
80
+ "worker_name": self.worker_addr,
81
+ "check_heart_beat": True,
82
+ "worker_status": self.get_status()
83
+ }
84
+ r = requests.post(url, json=data)
85
+ assert r.status_code == 200
86
+
87
+ def send_heart_beat(self):
88
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90
+ f"global_counter: {global_counter}")
91
+
92
+ url = self.controller_addr + "/receive_heart_beat"
93
+
94
+ while True:
95
+ try:
96
+ ret = requests.post(url, json={
97
+ "worker_name": self.worker_addr,
98
+ "queue_length": self.get_queue_length()}, timeout=5)
99
+ exist = ret.json()["exist"]
100
+ break
101
+ except requests.exceptions.RequestException as e:
102
+ logger.error(f"heart beat error: {e}")
103
+ time.sleep(5)
104
+
105
+ if not exist:
106
+ self.register_to_controller()
107
+
108
+ def get_queue_length(self):
109
+ if model_semaphore is None:
110
+ return 0
111
+ else:
112
+ return args.limit_model_concurrency - model_semaphore._value + (len(
113
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114
+
115
+ def get_status(self):
116
+ return {
117
+ "model_names": [self.model_name],
118
+ "speed": 1,
119
+ "queue_length": self.get_queue_length(),
120
+ }
121
+
122
+ @torch.inference_mode()
123
+ def generate_stream(self, params):
124
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125
+
126
+ prompt = params["prompt"]
127
+ ori_prompt = prompt
128
+ images = params.get("images", None)
129
+ num_image_tokens = 0
130
+ if images is not None and len(images) > 0 and self.is_multimodal:
131
+ if len(images) > 0:
132
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133
+ raise ValueError("Number of images does not match number of <|image|> tokens in prompt")
134
+
135
+ images = [load_image_from_base64(image) for image in images]
136
+ images = process_images(images, image_processor, model.config)
137
+
138
+ if type(images) is list:
139
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
140
+ else:
141
+ images = images.to(self.model.device, dtype=torch.float16)
142
+
143
+ replace_token = DEFAULT_IMAGE_TOKEN
144
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
145
+
146
+ num_image_tokens = prompt.count(replace_token) * (model.get_model().visual_abstractor.config.num_learnable_queries + 1)
147
+ else:
148
+ images = None
149
+ image_args = {"images": images}
150
+ else:
151
+ images = None
152
+ image_args = {}
153
+
154
+ temperature = float(params.get("temperature", 1.0))
155
+ top_p = float(params.get("top_p", 1.0))
156
+ max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
157
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
158
+ stop_str = params.get("stop", None)
159
+ do_sample = True if temperature > 0.001 else False
160
+
161
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
162
+ keywords = [stop_str]
163
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
164
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
165
+
166
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
167
+
168
+ if max_new_tokens < 1:
169
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
170
+ return
171
+
172
+ thread = Thread(target=model.generate, kwargs=dict(
173
+ inputs=input_ids,
174
+ do_sample=do_sample,
175
+ temperature=temperature,
176
+ top_p=top_p,
177
+ max_new_tokens=max_new_tokens,
178
+ streamer=streamer,
179
+ stopping_criteria=[stopping_criteria],
180
+ use_cache=True,
181
+ **image_args
182
+ ))
183
+ thread.start()
184
+
185
+ generated_text = ori_prompt
186
+ for new_text in streamer:
187
+ generated_text += new_text
188
+ if generated_text.endswith(stop_str):
189
+ generated_text = generated_text[:-len(stop_str)]
190
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
191
+
192
+ def generate_stream_gate(self, params):
193
+ try:
194
+ for x in self.generate_stream(params):
195
+ yield x
196
+ except ValueError as e:
197
+ print("Caught ValueError:", e)
198
+ ret = {
199
+ "text": server_error_msg,
200
+ "error_code": 1,
201
+ }
202
+ yield json.dumps(ret).encode() + b"\0"
203
+ except torch.cuda.CudaError as e:
204
+ print("Caught torch.cuda.CudaError:", e)
205
+ ret = {
206
+ "text": server_error_msg,
207
+ "error_code": 1,
208
+ }
209
+ yield json.dumps(ret).encode() + b"\0"
210
+ except Exception as e:
211
+ print("Caught Unknown Error", e)
212
+ ret = {
213
+ "text": server_error_msg,
214
+ "error_code": 1,
215
+ }
216
+ yield json.dumps(ret).encode() + b"\0"
217
+
218
+ app = FastAPI()
219
+
220
+ def release_model_semaphore(fn=None):
221
+ model_semaphore.release()
222
+ if fn is not None:
223
+ fn()
224
+
225
+
226
+ @app.post("/worker_generate_stream")
227
+ async def generate_stream(request: Request):
228
+ global model_semaphore, global_counter
229
+ global_counter += 1
230
+ params = await request.json()
231
+
232
+ if model_semaphore is None:
233
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
234
+ await model_semaphore.acquire()
235
+ worker.send_heart_beat()
236
+ generator = worker.generate_stream_gate(params)
237
+ background_tasks = BackgroundTasks()
238
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
239
+ return StreamingResponse(generator, background=background_tasks)
240
+
241
+
242
+ @app.post("/worker_get_status")
243
+ async def get_status(request: Request):
244
+ return worker.get_status()
245
+
246
+
247
+ if __name__ == "__main__":
248
+ parser = argparse.ArgumentParser()
249
+ parser.add_argument("--host", type=str, default="localhost")
250
+ parser.add_argument("--port", type=int, default=21002)
251
+ parser.add_argument("--worker-address", type=str,
252
+ default="http://localhost:21002")
253
+ parser.add_argument("--controller-address", type=str,
254
+ default="http://localhost:21001")
255
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
256
+ parser.add_argument("--model-base", type=str, default=None)
257
+ parser.add_argument("--model-name", type=str)
258
+ parser.add_argument("--device", type=str, default="cuda")
259
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
260
+ parser.add_argument("--stream-interval", type=int, default=1)
261
+ parser.add_argument("--no-register", action="store_true")
262
+ parser.add_argument("--load-8bit", action="store_true")
263
+ parser.add_argument("--load-4bit", action="store_true")
264
+ args = parser.parse_args()
265
+ logger.info(f"args: {args}")
266
+
267
+
268
+ worker = ModelWorker(args.controller_address,
269
+ args.worker_address,
270
+ worker_id,
271
+ args.no_register,
272
+ args.model_path,
273
+ args.model_base,
274
+ args.model_name,
275
+ args.load_8bit,
276
+ args.load_4bit,
277
+ args.device)
278
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
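For reference, a minimal client-side sketch of how the /worker_generate_stream endpoint above could be consumed. This sketch is not part of the upload: the worker address is the default from the argument parser, the chunk framing (JSON terminated by a null byte) follows the yields in generate_stream, and the payload keys are assumptions mirroring the generation parameters the handler uses (prompt, temperature, top_p, max_new_tokens).

import json
import requests

def stream_worker(prompt, worker_addr="http://localhost:21002", **gen_kwargs):
    # Payload keys mirror the parameters read by generate_stream above; any
    # field not visible in the worker code should be treated as an assumption.
    payload = {"prompt": prompt, "temperature": 0.7, "top_p": 0.9,
               "max_new_tokens": 512, **gen_kwargs}
    resp = requests.post(worker_addr + "/worker_generate_stream",
                         json=payload, stream=True, timeout=30)
    # The worker streams JSON chunks terminated by a null byte (b"\0").
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data.get("error_code", 0) != 0:
            raise RuntimeError(data["text"])
        yield data["text"]

# Hypothetical usage: for partial in stream_worker("Describe this document."): print(partial)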
mplug_docowl/serve/register_workers.py ADDED
@@ -0,0 +1,26 @@
1
+ """
2
+ Manually register workers.
3
+
4
+ Usage:
5
+ python3 -m mplug_docowl.serve.register_workers --controller-address http://localhost:21001 --worker-name http://localhost:21002
6
+ """
7
+
8
+ import argparse
9
+
10
+ import requests
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--controller-address", type=str)
15
+ parser.add_argument("--worker-name", type=str)
16
+ parser.add_argument("--check-heart-beat", action="store_true")
17
+ args = parser.parse_args()
18
+
19
+ url = args.controller_address + "/register_worker"
20
+ data = {
21
+ "worker_name": args.worker_name,
22
+ "check_heart_beat": args.check_heart_beat,
23
+ "worker_status": None,
24
+ }
25
+ r = requests.post(url, json=data)
26
+ assert r.status_code == 200
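The script above is a thin wrapper around a single controller call. A hedged equivalent as a bare requests call (the /register_worker route and payload keys come from the script itself; the addresses are the defaults used elsewhere in this upload):

import requests

controller_address = "http://localhost:21001"   # default controller address in this repo
payload = {
    "worker_name": "http://localhost:21002",    # the worker's advertised address
    "check_heart_beat": True,                   # equivalent to passing --check-heart-beat
    "worker_status": None,
}
r = requests.post(controller_address + "/register_worker", json=payload, timeout=10)
r.raise_for_status()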
mplug_docowl/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,117 @@
1
+ from typing import Optional, Tuple
2
+ import warnings
3
+
4
+ import torch
5
+
6
+ import transformers
7
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8
+
9
+ try:
10
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11
+ except ImportError:
12
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13
+ from flash_attn.bert_padding import unpad_input, pad_input
14
+
15
+
16
+ def forward(
17
+ self,
18
+ hidden_states: torch.Tensor,
19
+ modality_indicators: torch.Tensor,
20
+ attention_mask: Optional[torch.Tensor] = None,
21
+ position_ids: Optional[torch.Tensor] = None,
22
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
23
+ output_attentions: bool = False,
24
+ use_cache: bool = False,
25
+ padding_mask: bool = None,
26
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
27
+ if output_attentions:
28
+ warnings.warn(
29
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
30
+ )
31
+
32
+ bsz, q_len, _ = hidden_states.size()
33
+
34
+ query_states = (
35
+ self.q_proj(hidden_states)
36
+ .view(bsz, q_len, self.num_heads, self.head_dim)
37
+ .transpose(1, 2)
38
+ )
39
+ key_states = (
40
+ self.k_proj(hidden_states, modality_indicators)
41
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
42
+ .transpose(1, 2)
43
+ )
44
+ value_states = (
45
+ self.v_proj(hidden_states, modality_indicators)
46
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
47
+ .transpose(1, 2)
48
+ ) # shape: (b, num_heads, s, head_dim)
49
+
50
+ kv_seq_len = key_states.shape[-2]
51
+ if past_key_value is not None:
52
+ kv_seq_len += past_key_value[0].shape[-2]
53
+
54
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55
+ query_states, key_states = apply_rotary_pos_emb(
56
+ query_states, key_states, cos, sin, position_ids
57
+ )
58
+
59
+ if past_key_value is not None:
60
+ # reuse k, v
61
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
62
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
63
+
64
+ past_key_value = (key_states, value_states) if use_cache else None
65
+
66
+ # repeat k/v heads if n_kv_heads < n_heads
67
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
68
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
69
+
70
+ # Transform the data into the format required by flash attention
71
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
72
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
73
+ key_padding_mask = attention_mask
74
+
75
+ if key_padding_mask is None:
76
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
77
+ cu_q_lens = torch.arange(
78
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
79
+ )
80
+ max_s = q_len
81
+ output = flash_attn_unpadded_qkvpacked_func(
82
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
83
+ )
84
+ output = output.view(bsz, q_len, -1)
85
+ else:
86
+ qkv = qkv.reshape(bsz, q_len, -1)
87
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
88
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
89
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
90
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
91
+ )
92
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
93
+ output = pad_input(output_unpad, indices, bsz, q_len)
94
+
95
+ return self.o_proj(output), None, past_key_value
96
+
97
+
98
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
99
+ # requires the attention mask to be the same as the key_padding_mask
100
+ def _prepare_decoder_attention_mask(
101
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
102
+ ):
103
+ # [bsz, seq_len]
104
+ return attention_mask
105
+
106
+
107
+ def replace_llama_attn_with_flash_attn():
108
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
109
+ if cuda_major < 8:
110
+ warnings.warn(
111
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
112
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
113
+ )
114
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
115
+ _prepare_decoder_attention_mask
116
+ )
117
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
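As a small, self-contained illustration (not part of the upload) of what the unpadded branch above feeds to flash attention: with no padding mask, cu_q_lens is simply the cumulative sequence boundaries [0, q_len, 2*q_len, ...] that let the varlen kernel slice the packed qkv tensor back into per-sequence chunks.

import torch

bsz, q_len = 3, 5
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_q_lens)  # tensor([ 0,  5, 10, 15], dtype=torch.int32)
# Token rows i with cu_q_lens[b] <= i < cu_q_lens[b+1] belong to batch element b.

Note that replace_llama_attn_with_flash_attn() patches the class in place, so it must run before the LLaMA model is constructed; the patched forward also expects the repo's modified k_proj/v_proj (they take modality_indicators), so it is only meaningful together with the accompanying modeling code.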
mplug_docowl/train/mplug_owl2_trainer.py ADDED
@@ -0,0 +1,243 @@
1
+ import os
2
+ import torch
3
+
4
+ from torch.utils.data import Sampler
5
+
6
+ from transformers import Trainer
7
+ from transformers.trainer import (
8
+ is_sagemaker_mp_enabled,
9
+ get_parameter_names,
10
+ has_length,
11
+ ALL_LAYERNORM_LAYERS,
12
+ ShardedDDPOption,
13
+ logger,
14
+ )
15
+ from typing import List, Optional
16
+ from icecream import ic
17
+
18
+ def maybe_zero_3(param, ignore_status=False, name=None):
19
+ from deepspeed import zero
20
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
21
+ if hasattr(param, "ds_id"):
22
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
23
+ if not ignore_status:
24
+ print(name, 'no ignore status')
25
+ with zero.GatheredParameters([param]):
26
+ param = param.data.detach().cpu().clone()
27
+ else:
28
+ param = param.detach().cpu().clone()
29
+ return param
30
+
31
+
32
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
33
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
34
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
35
+ return to_return
36
+
37
+
38
+ def split_to_even_chunks(indices, lengths, num_chunks):
39
+ """
40
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
41
+ """
42
+
43
+ if len(indices) % num_chunks != 0:
44
+ return [indices[i::num_chunks] for i in range(num_chunks)]
45
+
46
+ num_indices_per_chunk = len(indices) // num_chunks
47
+
48
+ chunks = [[] for _ in range(num_chunks)]
49
+ chunks_lengths = [0 for _ in range(num_chunks)]
50
+ for index in indices:
51
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
52
+ chunks[shortest_chunk].append(index)
53
+ chunks_lengths[shortest_chunk] += lengths[index]
54
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
55
+ chunks_lengths[shortest_chunk] = float("inf")
56
+
57
+ return chunks
58
+
59
+
60
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
61
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
62
+ assert all(l != 0 for l in lengths), "Should not have zero length."
63
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
64
+ # all samples are in the same modality
65
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
66
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
67
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
68
+
69
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
70
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
71
+ megabatch_size = world_size * batch_size
72
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
73
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
74
+
75
+ last_mm = mm_megabatches[-1]
76
+ last_lang = lang_megabatches[-1]
77
+ additional_batch = last_mm + last_lang
78
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
79
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
80
+ megabatches = [megabatches[i] for i in megabatch_indices]
81
+
82
+ if len(additional_batch) > 0:
83
+ megabatches.append(sorted(additional_batch))
84
+
85
+ return [i for megabatch in megabatches for i in megabatch]
86
+
87
+
88
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
89
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
90
+ indices = torch.randperm(len(lengths), generator=generator)
91
+ megabatch_size = world_size * batch_size
92
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
93
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
94
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
95
+
96
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
97
+
98
+
99
+ class LengthGroupedSampler(Sampler):
100
+ r"""
101
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
102
+ keeping a bit of randomness.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ batch_size: int,
108
+ world_size: int,
109
+ lengths: Optional[List[int]] = None,
110
+ generator=None,
111
+ group_by_modality: bool = False,
112
+ ):
113
+ if lengths is None:
114
+ raise ValueError("Lengths must be provided.")
115
+
116
+ self.batch_size = batch_size
117
+ self.world_size = world_size
118
+ self.lengths = lengths
119
+ self.generator = generator
120
+ self.group_by_modality = group_by_modality
121
+
122
+ def __len__(self):
123
+ return len(self.lengths)
124
+
125
+ def __iter__(self):
126
+ if self.group_by_modality:
127
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
128
+ else:
129
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
130
+ return iter(indices)
131
+
132
+
133
+ class MPLUGOwl2Trainer(Trainer):
134
+
135
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
136
+ if self.train_dataset is None or not has_length(self.train_dataset):
137
+ return None
138
+
139
+ if self.args.group_by_modality_length:
140
+ lengths = self.train_dataset.modality_lengths
141
+ return LengthGroupedSampler(
142
+ self.args.train_batch_size,
143
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
144
+ lengths=lengths,
145
+ group_by_modality=True,
146
+ )
147
+ else:
148
+ return super()._get_train_sampler()
149
+
150
+ def create_optimizer(self):
151
+ """
152
+ Setup the optimizer.
153
+
154
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
155
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
156
+ """
157
+ if is_sagemaker_mp_enabled():
158
+ return super().create_optimizer()
159
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
160
+ return super().create_optimizer()
161
+
162
+ opt_model = self.model
163
+
164
+ if self.optimizer is None:
165
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
166
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
167
+ if self.args.visual_abstractor_lr is not None:
168
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "visual_abstractor" in name]  # match the abstractor's parameter names; the original check for "visual_abstractor_lr" matched nothing
169
+ optimizer_grouped_parameters = [
170
+ {
171
+ "params": [
172
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
173
+ ],
174
+ "weight_decay": self.args.weight_decay,
175
+ },
176
+ {
177
+ "params": [
178
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
179
+ ],
180
+ "weight_decay": 0.0,
181
+ },
182
+ {
183
+ "params": [
184
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
185
+ ],
186
+ "weight_decay": self.args.weight_decay,
187
+ "lr": self.args.visual_abstractor_lr,
188
+ },
189
+ {
190
+ "params": [
191
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
192
+ ],
193
+ "weight_decay": 0.0,
194
+ "lr": self.args.visual_abstractor_lr,
195
+ },
196
+ ]
197
+ else:
198
+ optimizer_grouped_parameters = [
199
+ {
200
+ "params": [
201
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
202
+ ],
203
+ "weight_decay": self.args.weight_decay,
204
+ },
205
+ {
206
+ "params": [
207
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
208
+ ],
209
+ "weight_decay": 0.0,
210
+ },
211
+ ]
212
+ ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params']))
213
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
214
+
215
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
216
+ self.optimizer = OSS(
217
+ params=optimizer_grouped_parameters,
218
+ optim=optimizer_cls,
219
+ **optimizer_kwargs,
220
+ )
221
+ else:
222
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
223
+ if optimizer_cls.__name__ == "Adam8bit":
224
+ import bitsandbytes
225
+
226
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
227
+
228
+ skipped = 0
229
+ for module in opt_model.modules():
230
+ if isinstance(module, torch.nn.Embedding):  # fully qualified, since nn is not imported in this module
231
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
232
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
233
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
234
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
235
+ logger.info(f"skipped: {skipped/2**20}M params")
236
+
237
+ return self.optimizer
238
+
239
+ def _save_checkpoint(self, model, trial, metrics=None):
240
+ super(MPLUGOwl2Trainer, self)._save_checkpoint(model, trial, metrics)
241
+
242
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
243
+ super(MPLUGOwl2Trainer, self)._save(output_dir, state_dict)
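A hedged, toy illustration of the length-grouped sampling implemented above (assuming this package and the transformers/icecream versions it was written against are installed so the module imports cleanly): indices are shuffled, cut into world_size * batch_size megabatches, sorted by length within each megabatch, and then balanced across ranks by split_to_even_chunks.

import torch

from mplug_docowl.train.mplug_owl2_trainer import get_length_grouped_indices

lengths = [3, 50, 7, 45, 9, 48, 5, 47]               # toy per-sample lengths
g = torch.Generator().manual_seed(0)                  # fixed seed so the demo is reproducible
order = get_length_grouped_indices(lengths, batch_size=2, world_size=2, generator=g)
print(order)                                          # a permutation of range(8)
print([lengths[i] for i in order])                    # samples of similar length end up adjacent within a megabatch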
mplug_docowl/train/train.py ADDED
@@ -0,0 +1,801 @@
1
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import copy
19
+ from dataclasses import dataclass, field
20
+ import json
21
+ import logging
22
+ import pathlib
23
+ from typing import Dict, Optional, Sequence, List
24
+
25
+ import torch
26
+
27
+ import transformers
28
+ from transformers.models.clip.image_processing_clip import CLIPImageProcessor
29
+
30
+ from torch.utils.data import Dataset
31
+ from mplug_owl2.train.mplug_owl2_trainer import MPLUGOwl2Trainer
32
+ from mplug_owl2.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
33
+
34
+ from mplug_owl2 import conversation as conversation_lib
35
+ from mplug_owl2.model import *
36
+ from mplug_owl2.mm_utils import tokenizer_image_token
37
+
38
+ from PIL import Image
39
+ from icecream import ic
40
+
41
+ local_rank = None
42
+
43
+
44
+ def rank0_print(*args):
45
+ if local_rank == 0:
46
+ print(*args)
47
+
48
+
49
+ @dataclass
50
+ class ModelArguments:
51
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
52
+ version: Optional[str] = field(default="v0")
53
+ freeze_backbone: bool = field(default=False)
54
+
55
+ @dataclass
56
+ class DataArguments:
57
+ data_path: str = field(default=None,
58
+ metadata={"help": "Path to the training data."})
59
+ lazy_preprocess: bool = False
60
+ is_multimodal: bool = False
61
+ image_folder: Optional[str] = field(default=None)
62
+ image_aspect_ratio: str = 'square'
63
+ image_grid_pinpoints: Optional[str] = field(default=None)
64
+
65
+
66
+ @dataclass
67
+ class TrainingArguments(transformers.TrainingArguments):
68
+ cache_dir: Optional[str] = field(default=None)
69
+ optim: str = field(default="adamw_torch")
70
+ remove_unused_columns: bool = field(default=False)
71
+
72
+ tune_visual_abstractor: bool = field(default=True)
73
+ freeze_vision_model: bool = field(default=True)
74
+
75
+ model_max_length: int = field(
76
+ default=512,
77
+ metadata={
78
+ "help":
79
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
80
+ },
81
+ )
82
+ double_quant: bool = field(
83
+ default=True,
84
+ metadata={"help": "Compress the quantization statistics through double quantization."}
85
+ )
86
+ quant_type: str = field(
87
+ default="nf4",
88
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
89
+ )
90
+ bits: int = field(
91
+ default=16,
92
+ metadata={"help": "How many bits to use."}
93
+ )
94
+ lora_enable: bool = False
95
+ lora_r: int = 64
96
+ lora_alpha: int = 16
97
+ lora_dropout: float = 0.05
98
+ lora_weight_path: str = ""
99
+ lora_bias: str = "none"
100
+ visual_abstractor_lr: Optional[float] = None
101
+ group_by_modality_length: bool = field(default=False)
102
+
103
+
104
+ def maybe_zero_3(param, ignore_status=False, name=None):
105
+ from deepspeed import zero
106
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
107
+ if hasattr(param, "ds_id"):
108
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
109
+ if not ignore_status:
110
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
111
+ with zero.GatheredParameters([param]):
112
+ param = param.data.detach().cpu().clone()
113
+ else:
114
+ param = param.detach().cpu().clone()
115
+ return param
116
+
117
+
118
+ # Borrowed from peft.utils.get_peft_model_state_dict
119
+ def get_peft_state_maybe_zero_3(named_params, bias):
120
+ if bias == "none":
121
+ to_return = {k: t for k, t in named_params if "lora_" in k}
122
+ elif bias == "all":
123
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
124
+ elif bias == "lora_only":
125
+ to_return = {}
126
+ maybe_lora_bias = {}
127
+ lora_bias_names = set()
128
+ for k, t in named_params:
129
+ if "lora_" in k:
130
+ to_return[k] = t
131
+ bias_name = k.split("lora_")[0] + "bias"
132
+ lora_bias_names.add(bias_name)
133
+ elif "bias" in k:
134
+ maybe_lora_bias[k] = t
135
+ for k, t in maybe_lora_bias:
136
+ if bias_name in lora_bias_names:
137
+ to_return[bias_name] = t
138
+ else:
139
+ raise NotImplementedError
140
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
141
+ return to_return
142
+
143
+
144
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
145
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
146
+ if require_grad_only:
147
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
148
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
149
+ return to_return
150
+
151
+
152
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
153
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
154
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
155
+ return to_return
156
+
157
+
158
+ def find_all_linear_names(model):
159
+ cls = torch.nn.Linear
160
+ lora_module_names = set()
161
+ multimodal_keywords = ['vision_model', 'visual_abstractor']
162
+ for name, module in model.named_modules():
163
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
164
+ continue
165
+ if isinstance(module, cls):
166
+ lora_module_names.add(name)
167
+
168
+ if 'lm_head' in lora_module_names: # needed for 16-bit
169
+ lora_module_names.remove('lm_head')
170
+ return list(lora_module_names)
171
+
172
+
173
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
174
+ output_dir: str):
175
+ """Collects the state dict and dump to disk."""
176
+
177
+ if trainer.deepspeed:
178
+ torch.cuda.synchronize()
179
+ trainer.save_model(output_dir)
180
+ return
181
+
182
+ state_dict = trainer.model.state_dict()
183
+ if trainer.args.should_save:
184
+ cpu_state_dict = {
185
+ key: value.cpu()
186
+ for key, value in state_dict.items()
187
+ }
188
+ del state_dict
189
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
190
+
191
+
192
+ def smart_tokenizer_and_embedding_resize(
193
+ special_tokens_dict: Dict,
194
+ tokenizer: transformers.PreTrainedTokenizer,
195
+ model: transformers.PreTrainedModel,
196
+ ):
197
+ """Resize tokenizer and embedding.
198
+
199
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
200
+ """
201
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
202
+ model.resize_token_embeddings(len(tokenizer))
203
+
204
+ if num_new_tokens > 0:
205
+ input_embeddings = model.get_input_embeddings().weight.data
206
+ output_embeddings = model.get_output_embeddings().weight.data
207
+
208
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
209
+ dim=0, keepdim=True)
210
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
211
+ dim=0, keepdim=True)
212
+
213
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
214
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
215
+
216
+
217
+ def _tokenize_fn(strings: Sequence[str],
218
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
219
+ """Tokenize a list of strings."""
220
+ tokenized_list = [
221
+ tokenizer(
222
+ text,
223
+ return_tensors="pt",
224
+ padding="longest",
225
+ max_length=tokenizer.model_max_length,
226
+ truncation=True,
227
+ ) for text in strings
228
+ ]
229
+ input_ids = labels = [
230
+ tokenized.input_ids[0] for tokenized in tokenized_list
231
+ ]
232
+ input_ids_lens = labels_lens = [
233
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
234
+ for tokenized in tokenized_list
235
+ ]
236
+ return dict(
237
+ input_ids=input_ids,
238
+ labels=labels,
239
+ input_ids_lens=input_ids_lens,
240
+ labels_lens=labels_lens,
241
+ )
242
+
243
+
244
+ def _mask_targets(target, tokenized_lens, speakers):
245
+ # cur_idx = 0
246
+ cur_idx = tokenized_lens[0]
247
+ tokenized_lens = tokenized_lens[1:]
248
+ target[:cur_idx] = IGNORE_INDEX
249
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
250
+ if speaker == "human":
251
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
252
+ cur_idx += tokenized_len
253
+
254
+
255
+ def _add_speaker_and_signal(header, source, get_conversation=True):
256
+ """Add speaker and start/end signal on each round."""
257
+ BEGIN_SIGNAL = "### "
258
+ END_SIGNAL = "\n"
259
+ conversation = header
260
+ for sentence in source:
261
+ from_str = sentence["from"]
262
+ if from_str.lower() == "human":
263
+ from_str = conversation_lib.default_conversation.roles[0]
264
+ elif from_str.lower() == "gpt":
265
+ from_str = conversation_lib.default_conversation.roles[1]
266
+ else:
267
+ from_str = 'unknown'
268
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
269
+ sentence["value"] + END_SIGNAL)
270
+ if get_conversation:
271
+ conversation += sentence["value"]
272
+ conversation += BEGIN_SIGNAL
273
+ return conversation
274
+
275
+
276
+ def preprocess_multimodal(
277
+ sources: Sequence[str],
278
+ data_args: DataArguments
279
+ ) -> Dict:
280
+ is_multimodal = data_args.is_multimodal
281
+ if not is_multimodal:
282
+ return sources
283
+
284
+ for source in sources:
285
+ for sentence in source:
286
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
287
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
288
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
289
+ sentence['value'] = sentence['value'].strip()
290
+
291
+ replace_token = DEFAULT_IMAGE_TOKEN
292
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
293
+
294
+ return sources
295
+
296
+
297
+ def preprocess_v1(
298
+ sources,
299
+ tokenizer: transformers.PreTrainedTokenizer,
300
+ has_image: bool = False
301
+ ) -> Dict:
302
+ conv = conversation_lib.default_conversation.copy()
303
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
304
+
305
+ # Apply prompt templates
306
+ conversations = []
307
+ for i, source in enumerate(sources):
308
+ if roles[source[0]["from"]] != conv.roles[0]:
309
+ # Skip the first one if it is not from human
310
+ source = source[1:]
311
+
312
+ conv.messages = []
313
+ for j, sentence in enumerate(source):
314
+ role = roles[sentence["from"]]
315
+ assert role == conv.roles[j % 2], f"{i}"
316
+ conv.append_message(role, sentence["value"])
317
+ conversations.append(conv.get_prompt())
318
+
319
+ # Tokenize conversations
320
+
321
+ if has_image:
322
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
323
+ else:
324
+ input_ids = tokenizer(
325
+ conversations,
326
+ return_tensors="pt",
327
+ padding="longest",
328
+ max_length=tokenizer.model_max_length,
329
+ truncation=True,
330
+ ).input_ids
331
+
332
+ targets = input_ids.clone()
333
+
334
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
335
+
336
+ # Mask targets
337
+ sep = conv.sep + conv.roles[1] + ": "
338
+ for conversation, target in zip(conversations, targets):
339
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
340
+
341
+ rounds = conversation.split(conv.sep2)
342
+ cur_len = 1
343
+ target[:cur_len] = IGNORE_INDEX
344
+ for i, rou in enumerate(rounds):
345
+ if rou == "":
346
+ break
347
+
348
+ parts = rou.split(sep)
349
+ if len(parts) != 2:
350
+ break
351
+ parts[0] += sep
352
+
353
+ if has_image:
354
+ round_len = len(tokenizer_image_token(rou, tokenizer))
355
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
356
+ else:
357
+ round_len = len(tokenizer(rou).input_ids)
358
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
359
+
360
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
361
+
362
+ cur_len += round_len
363
+ target[cur_len:] = IGNORE_INDEX
364
+
365
+ if cur_len < tokenizer.model_max_length:
366
+ if cur_len != total_len:
367
+ target[:] = IGNORE_INDEX
368
+ print(
369
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
370
+ f" (ignored)"
371
+ )
372
+
373
+ return dict(
374
+ input_ids=input_ids,
375
+ labels=targets,
376
+ )
377
+
378
+
379
+ def preprocess_plain(
380
+ sources: Sequence[str],
381
+ tokenizer: transformers.PreTrainedTokenizer,
382
+ ) -> Dict:
383
+ # add end signal and concatenate together
384
+ conversations = []
385
+ for source in sources:
386
+ assert len(source) == 2
387
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
388
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
389
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
390
+ conversations.append(conversation)
391
+ # tokenize conversations
392
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
393
+ targets = copy.deepcopy(input_ids)
394
+ for target, source in zip(targets, sources):
395
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
396
+ target[:tokenized_len] = IGNORE_INDEX
397
+
398
+ return dict(input_ids=input_ids, labels=targets)
399
+
400
+
401
+ def preprocess(
402
+ sources: Sequence[str],
403
+ tokenizer: transformers.PreTrainedTokenizer,
404
+ has_image: bool = False
405
+ ) -> Dict:
406
+ """
407
+ Given a list of sources, each is a conversation list. This transform:
408
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
409
+ 2. Concatenate conversations together;
410
+ 3. Tokenize the concatenated conversation;
411
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
412
+ """
413
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
414
+ return preprocess_plain(sources, tokenizer)
415
+ if conversation_lib.default_conversation.version.startswith("v1"):
416
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
417
+ # add end signal and concatenate together
418
+ conversations = []
419
+ for source in sources:
420
+ header = f"{conversation_lib.default_conversation.system}\n\n"
421
+ conversation = _add_speaker_and_signal(header, source)
422
+ conversations.append(conversation)
423
+ # tokenize conversations
424
+ def get_tokenize_len(prompts):
425
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
426
+ if has_image:
427
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
428
+ else:
429
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
430
+ input_ids = conversations_tokenized["input_ids"]
431
+
432
+ targets = copy.deepcopy(input_ids)
433
+ for target, source in zip(targets, sources):
434
+ if has_image:
435
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
436
+ else:
437
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
438
+ speakers = [sentence["from"] for sentence in source]
439
+ _mask_targets(target, tokenized_lens, speakers)
440
+
441
+ return dict(input_ids=input_ids, labels=targets)
442
+
443
+
444
+ class LazySupervisedDataset(Dataset):
445
+ """Dataset for supervised fine-tuning."""
446
+
447
+ def __init__(self, data_path: str,
448
+ tokenizer: transformers.PreTrainedTokenizer,
449
+ data_args: DataArguments):
450
+ super(LazySupervisedDataset, self).__init__()
451
+ list_data_dict = json.load(open(data_path, "r"))
452
+
453
+ rank0_print("Formatting inputs...Skip in lazy mode")
454
+ self.tokenizer = tokenizer
455
+ self.list_data_dict = list_data_dict
456
+ self.data_args = data_args
457
+
458
+ def __len__(self):
459
+ return len(self.list_data_dict)
460
+
461
+ @property
462
+ def lengths(self):
463
+ length_list = []
464
+ for sample in self.list_data_dict:
465
+ img_tokens = 128 if 'image' in sample else 0
466
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
467
+ return length_list
468
+
469
+
470
+ @property
471
+ def modality_lengths(self):
472
+ length_list = []
473
+ for sample in self.list_data_dict:
474
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
475
+ cur_len = cur_len if 'image' in sample else -cur_len
476
+ length_list.append(cur_len)
477
+ return length_list
478
+
479
+ # def __getitem__(self, i) -> Dict[str, torch.Tensor]:
480
+ # sources = self.list_data_dict[i]
481
+ # if isinstance(i, int):
482
+ # sources = [sources]
483
+ # assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
484
+ # if 'image' in sources[0]:
485
+ # image_file = self.list_data_dict[i]['image']
486
+ # image_folder = self.data_args.image_folder
487
+ # processor = self.data_args.image_processor
488
+ # image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
489
+ # if self.data_args.image_aspect_ratio == 'pad':
490
+ # def expand2square(pil_img, background_color):
491
+ # width, height = pil_img.size
492
+ # if width == height:
493
+ # return pil_img
494
+ # elif width > height:
495
+ # result = Image.new(pil_img.mode, (width, width), background_color)
496
+ # result.paste(pil_img, (0, (width - height) // 2))
497
+ # return result
498
+ # else:
499
+ # result = Image.new(pil_img.mode, (height, height), background_color)
500
+ # result.paste(pil_img, ((height - width) // 2, 0))
501
+ # return result
502
+ # image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
503
+ # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
504
+ # else:
505
+ # image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
506
+ # sources = preprocess_multimodal(
507
+ # copy.deepcopy([e["conversations"] for e in sources]),
508
+ # self.data_args)
509
+ # else:
510
+ # sources = copy.deepcopy([e["conversations"] for e in sources])
511
+ # data_dict = preprocess(
512
+ # sources,
513
+ # self.tokenizer,
514
+ # has_image=('image' in self.list_data_dict[i]))
515
+ # if isinstance(i, int):
516
+ # data_dict = dict(input_ids=data_dict["input_ids"][0],
517
+ # labels=data_dict["labels"][0])
518
+
519
+ # # image exist in the data
520
+ # if 'image' in self.list_data_dict[i]:
521
+ # data_dict['image'] = image
522
+ # elif self.data_args.is_multimodal:
523
+ # # image does not exist in the data, but the model is multimodal
524
+ # crop_size = self.data_args.image_processor.crop_size
525
+ # data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
526
+ # return data_dict
527
+
528
+ def next_rand(self):
529
+ import random
530
+ return random.randint(0,len(self)-1)
531
+
532
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
533
+ while True:
534
+ sources = self.list_data_dict[i]
535
+ if isinstance(i, int):
536
+ sources = [sources]
537
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
538
+ if 'image' in sources[0]:
539
+
540
+ image_file = self.list_data_dict[i]['image']
541
+ image_folder = self.data_args.image_folder
542
+ processor = self.data_args.image_processor
543
+ from pathlib import Path
544
+ if not Path(os.path.join(image_folder, image_file)).exists():
545
+ i = self.next_rand()
546
+ continue
547
+ image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
548
+ if self.data_args.image_aspect_ratio == 'pad':
549
+ def expand2square(pil_img, background_color):
550
+ width, height = pil_img.size
551
+ if width == height:
552
+ return pil_img
553
+ elif width > height:
554
+ result = Image.new(pil_img.mode, (width, width), background_color)
555
+ result.paste(pil_img, (0, (width - height) // 2))
556
+ return result
557
+ else:
558
+ result = Image.new(pil_img.mode, (height, height), background_color)
559
+ result.paste(pil_img, ((height - width) // 2, 0))
560
+ return result
561
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
562
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
563
+ else:
564
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
565
+ sources = preprocess_multimodal(
566
+ copy.deepcopy([e["conversations"] for e in sources]),
567
+ self.data_args)
568
+ else:
569
+
570
+ sources = copy.deepcopy([e["conversations"] for e in sources])
571
+ data_dict = preprocess(
572
+ sources,
573
+ self.tokenizer,
574
+ has_image=('image' in self.list_data_dict[i]))
575
+ if isinstance(i, int):
576
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
577
+ labels=data_dict["labels"][0])
578
+
579
+ # image exist in the data
580
+ if 'image' in self.list_data_dict[i]:
581
+ data_dict['image'] = image
582
+ elif self.data_args.is_multimodal:
583
+ # image does not exist in the data, but the model is multimodal
584
+ crop_size = self.data_args.image_processor.crop_size
585
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
586
+ return data_dict
587
+
588
+
589
+ @dataclass
590
+ class DataCollatorForSupervisedDataset(object):
591
+ """Collate examples for supervised fine-tuning."""
592
+
593
+ tokenizer: transformers.PreTrainedTokenizer
594
+
595
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
596
+ input_ids, labels = tuple([instance[key] for instance in instances]
597
+ for key in ("input_ids", "labels"))
598
+ input_ids = torch.nn.utils.rnn.pad_sequence(
599
+ input_ids,
600
+ batch_first=True,
601
+ padding_value=self.tokenizer.pad_token_id)
602
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
603
+ batch_first=True,
604
+ padding_value=IGNORE_INDEX)
605
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
606
+ labels = labels[:, :self.tokenizer.model_max_length]
607
+ batch = dict(
608
+ input_ids=input_ids,
609
+ labels=labels,
610
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
611
+ )
612
+
613
+ if 'image' in instances[0]:
614
+ images = [instance['image'] for instance in instances]
615
+ if all(x is not None and x.shape == images[0].shape for x in images):
616
+ batch['images'] = torch.stack(images)
617
+ else:
618
+ batch['images'] = images
619
+
620
+ return batch
621
+
622
+
623
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
624
+ data_args) -> Dict:
625
+ """Make dataset and collator for supervised fine-tuning."""
626
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
627
+ data_path=data_args.data_path,
628
+ data_args=data_args)
629
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
630
+ return dict(train_dataset=train_dataset,
631
+ eval_dataset=None,
632
+ data_collator=data_collator)
633
+
634
+
635
+ def train():
636
+ global local_rank
637
+
638
+ parser = transformers.HfArgumentParser(
639
+ (ModelArguments, DataArguments, TrainingArguments))
640
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
641
+ local_rank = training_args.local_rank
642
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
643
+
644
+ bnb_model_from_pretrained_args = {}
645
+ if training_args.bits in [4, 8]:
646
+ from transformers import BitsAndBytesConfig
647
+ bnb_model_from_pretrained_args.update(dict(
648
+ device_map={"": training_args.device},
649
+ load_in_4bit=training_args.bits == 4,
650
+ load_in_8bit=training_args.bits == 8,
651
+ quantization_config=BitsAndBytesConfig(
652
+ load_in_4bit=training_args.bits == 4,
653
+ load_in_8bit=training_args.bits == 8,
654
+ llm_int8_threshold=6.0,
655
+ llm_int8_has_fp16_weight=False,
656
+ bnb_4bit_compute_dtype=compute_dtype,
657
+ bnb_4bit_use_double_quant=training_args.double_quant,
658
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
659
+ )
660
+ ))
661
+
662
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
663
+ model_args.model_name_or_path,
664
+ cache_dir=training_args.cache_dir,
665
+ **bnb_model_from_pretrained_args
666
+ )
667
+ model.config.use_cache = False
668
+
669
+ if model_args.freeze_backbone:
670
+ model.model.requires_grad_(False)
671
+
672
+ if training_args.bits in [4, 8]:
673
+ from peft import prepare_model_for_kbit_training
674
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
675
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
676
+
677
+ if training_args.gradient_checkpointing:
678
+ if hasattr(model, "enable_input_require_grads"):
679
+ model.enable_input_require_grads()
680
+ else:
681
+ def make_inputs_require_grad(module, input, output):
682
+ output.requires_grad_(True)
683
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
684
+
685
+ if training_args.lora_enable:
686
+ from peft import LoraConfig, get_peft_model
687
+ lora_config = LoraConfig(
688
+ r=training_args.lora_r,
689
+ lora_alpha=training_args.lora_alpha,
690
+ target_modules=find_all_linear_names(model),
691
+ lora_dropout=training_args.lora_dropout,
692
+ bias=training_args.lora_bias,
693
+ task_type="CAUSAL_LM",
694
+ )
695
+ if training_args.bits == 16:
696
+ if training_args.bf16:
697
+ model.to(torch.bfloat16)
698
+ if training_args.fp16:
699
+ model.to(torch.float16)
700
+ rank0_print("Adding LoRA adapters...")
701
+ model = get_peft_model(model, lora_config)
702
+
703
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
704
+ model_args.model_name_or_path,
705
+ cache_dir=training_args.cache_dir,
706
+ model_max_length=training_args.model_max_length,
707
+ padding_side="right",
708
+ use_fast=False,
709
+ )
710
+
711
+
712
+ tokenizer.pad_token = tokenizer.unk_token
713
+ if model_args.version in conversation_lib.conv_templates:
714
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
715
+ else:
716
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
717
+
718
+ if not training_args.freeze_vision_model and training_args.bits in [4, 8]:
719
+ model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device)
720
+ else:
721
+ vision_tower = model.get_model().vision_model
722
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
723
+
724
+ if training_args.tune_visual_abstractor and training_args.bits in [4, 8]:
725
+ model.get_model().visual_abstractor.to(dtype=compute_dtype, device=training_args.device)
726
+ else:
727
+ visual_abstractor = model.get_model().visual_abstractor
728
+ visual_abstractor.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
729
+
730
+ data_args.image_processor = CLIPImageProcessor.from_pretrained(model_args.model_name_or_path)
731
+ data_args.is_multimodal = True
732
+
733
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
734
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
735
+ model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = training_args.tune_visual_abstractor
736
+ ic(training_args.tune_visual_abstractor)
737
+ model.requires_grad_(True)
738
+ if training_args.tune_visual_abstractor:
739
+ # model.requires_grad_(False)
740
+ for p in model.get_model().visual_abstractor.parameters():
741
+ p.requires_grad = True
742
+
743
+ model.config.freeze_vision_model = training_args.freeze_vision_model
744
+ ic(training_args.freeze_vision_model)
745
+ if training_args.freeze_vision_model:
746
+ for p in model.get_model().vision_model.parameters():
747
+ p.requires_grad = False
748
+
749
+ model.config.visual_abstractor_lr = training_args.visual_abstractor_lr
750
+
751
+
752
+ if training_args.bits in [4, 8]:
753
+ from peft.tuners.lora import LoraLayer
754
+ for name, module in model.named_modules():
755
+ if isinstance(module, LoraLayer):
756
+ if training_args.bf16:
757
+ module = module.to(torch.bfloat16)
758
+ if 'norm' in name:
759
+ module = module.to(torch.float32)
760
+ if 'lm_head' in name or 'embed_tokens' in name:
761
+ if hasattr(module, 'weight'):
762
+ if training_args.bf16 and module.weight.dtype == torch.float32:
763
+ module = module.to(torch.bfloat16)
764
+
765
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
766
+ data_args=data_args)
767
+ trainer = MPLUGOwl2Trainer(model=model,
768
+ tokenizer=tokenizer,
769
+ args=training_args,
770
+ **data_module)
771
+
772
+ # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
773
+ # trainer.train(resume_from_checkpoint=True)
774
+ # else:
775
+ # trainer.train()
776
+
777
+ # TODO I dont like auto resume << REMOVE IT AND UNCOMMENT THE ABOVE CODE
778
+ trainer.train()
779
+
780
+ trainer.save_state()
781
+
782
+ model.config.use_cache = True
783
+
784
+ if training_args.lora_enable:
785
+ state_dict = get_peft_state_maybe_zero_3(
786
+ model.named_parameters(), training_args.lora_bias
787
+ )
788
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
789
+ model.named_parameters()
790
+ )
791
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
792
+ model.config.save_pretrained(training_args.output_dir)
793
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
794
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
795
+ else:
796
+ safe_save_model_for_hf_trainer(trainer=trainer,
797
+ output_dir=training_args.output_dir)
798
+
799
+
800
+ if __name__ == "__main__":
801
+ train()
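Finally, a standalone sketch of the padding scheme used by DataCollatorForSupervisedDataset above: input_ids are right-padded with the tokenizer's pad id, labels with IGNORE_INDEX so padded positions are ignored by the loss, and the attention mask is recovered as "not pad". The pad id and the -100 value below are illustrative assumptions; the real values come from the tokenizer and from mplug_docowl.constants.

import torch

IGNORE_INDEX = -100   # assumed value of mplug_docowl.constants.IGNORE_INDEX
PAD_TOKEN_ID = 0      # assumed pad id; train() sets tokenizer.pad_token = tokenizer.unk_token

input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
labels    = [torch.tensor([IGNORE_INDEX, 6, 7]), torch.tensor([IGNORE_INDEX, 9])]

batch_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=PAD_TOKEN_ID)
batch_lbl = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
attn_mask = batch_ids.ne(PAD_TOKEN_ID)

print(batch_ids)   # tensor([[5, 6, 7], [8, 9, 0]])
print(batch_lbl)   # tensor([[-100, 6, 7], [-100, 9, -100]])
print(attn_mask)   # tensor([[ True,  True,  True], [ True,  True, False]])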