manysuch-cases
commited on
Upload 3 files
Browse files- apollo/constants.py +31 -0
- apollo/conversation.py +544 -0
- apollo/mm_utils.py +579 -0
apollo/constants.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
17 |
+
|
18 |
+
|
19 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
20 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
21 |
+
|
22 |
+
LOGDIR = "."
|
23 |
+
|
24 |
+
|
25 |
+
# Model Constants
|
26 |
+
IGNORE_INDEX = -100
|
27 |
+
X_TOKEN_INDEX = -200
|
28 |
+
X_TOKEN = {'image': "<|image_token|>", 'video': "<|video_token|>"}
|
29 |
+
X_PATCH_TOKEN = {'image': "<|image_patch|>", 'video': "<|video_patch|>"}
|
30 |
+
X_START_TOKEN = {'image': "<|image_start|>", 'video': "<|video_start|>"}
|
31 |
+
X_END_TOKEN = {'image': "<|image_end|>", 'video': "<|video_end|>"}
|
apollo/conversation.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
17 |
+
|
18 |
+
|
19 |
+
import dataclasses
|
20 |
+
from enum import auto, Enum
|
21 |
+
from typing import List, Tuple
|
22 |
+
|
23 |
+
|
24 |
+
class SeparatorStyle(Enum):
|
25 |
+
"""Different separator style."""
|
26 |
+
SINGLE = auto()
|
27 |
+
TWO = auto()
|
28 |
+
MPT = auto()
|
29 |
+
PLAIN = auto()
|
30 |
+
LLAMA_2 = auto()
|
31 |
+
LLAMA_3 = auto()
|
32 |
+
MISTRAL = auto()
|
33 |
+
CHATML = auto()
|
34 |
+
QWEN = auto()
|
35 |
+
QWEN_2 = auto()
|
36 |
+
GEMMA = auto()
|
37 |
+
|
38 |
+
|
39 |
+
@dataclasses.dataclass
|
40 |
+
class Conversation:
|
41 |
+
"""A class that keeps all conversation history."""
|
42 |
+
system: str
|
43 |
+
roles: List[str]
|
44 |
+
messages: List[List[str]]
|
45 |
+
offset: int
|
46 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
47 |
+
sep: str = "###"
|
48 |
+
sep2: str = None
|
49 |
+
version: str = "Unknown"
|
50 |
+
|
51 |
+
skip_next: bool = False
|
52 |
+
|
53 |
+
def get_prompt(self):
|
54 |
+
messages = self.messages
|
55 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
56 |
+
messages = self.messages.copy()
|
57 |
+
init_role, init_msg = messages[0].copy()
|
58 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
59 |
+
if 'mmtag' in self.version:
|
60 |
+
messages[0] = (init_role, init_msg)
|
61 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
62 |
+
messages.insert(1, (self.roles[1], "Received."))
|
63 |
+
else:
|
64 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
65 |
+
|
66 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
67 |
+
ret = self.system + self.sep
|
68 |
+
for role, message in messages:
|
69 |
+
if message:
|
70 |
+
if type(message) is tuple:
|
71 |
+
message, _, _ = message
|
72 |
+
ret += role + ": " + message + self.sep
|
73 |
+
else:
|
74 |
+
ret += role + ":"
|
75 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
76 |
+
seps = [self.sep, self.sep2]
|
77 |
+
ret = self.system + seps[0]
|
78 |
+
for i, (role, message) in enumerate(messages):
|
79 |
+
if message:
|
80 |
+
if type(message) is tuple:
|
81 |
+
message, _, _ = message
|
82 |
+
ret += role + ": " + message + seps[i % 2]
|
83 |
+
else:
|
84 |
+
ret += role + ":"
|
85 |
+
elif self.sep_style == SeparatorStyle.QWEN_2:
|
86 |
+
seps = [self.sep, self.sep2]
|
87 |
+
ret = self.system + seps[0]
|
88 |
+
for i, (role, message) in enumerate(messages):
|
89 |
+
if message:
|
90 |
+
if type(message) is tuple:
|
91 |
+
message, _, _ = message
|
92 |
+
ret += role + ": " + message + seps[i % 2]
|
93 |
+
else:
|
94 |
+
ret += role + ":"
|
95 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
96 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
97 |
+
for role, message in messages:
|
98 |
+
if message:
|
99 |
+
if type(message) is tuple:
|
100 |
+
#TODO! NEED to add MM support!
|
101 |
+
message, images = message
|
102 |
+
message = "<image>" * len(images) + message
|
103 |
+
ret += role + "\n" + message + self.sep + "\n"
|
104 |
+
else:
|
105 |
+
ret += role + "\n"
|
106 |
+
return ret
|
107 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
108 |
+
ret = self.system + self.sep
|
109 |
+
for role, message in messages:
|
110 |
+
if message:
|
111 |
+
if type(message) is tuple:
|
112 |
+
message = message[0]
|
113 |
+
ret += role + message + self.sep
|
114 |
+
else:
|
115 |
+
ret += role
|
116 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
117 |
+
ret = self.system + self.sep
|
118 |
+
for role, message in messages:
|
119 |
+
if message:
|
120 |
+
if type(message) is tuple:
|
121 |
+
message, _, _ = message
|
122 |
+
ret += role + message + self.sep
|
123 |
+
else:
|
124 |
+
ret += role
|
125 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
126 |
+
ret = ""
|
127 |
+
for i, (role, message) in enumerate(messages):
|
128 |
+
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
|
129 |
+
if message:
|
130 |
+
if type(message) is tuple:
|
131 |
+
message, _, _ = message
|
132 |
+
ret += role + message + self.sep
|
133 |
+
else:
|
134 |
+
ret += role
|
135 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
|
136 |
+
if self.sep_style == SeparatorStyle.LLAMA_2:
|
137 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
138 |
+
else:
|
139 |
+
wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
|
140 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
141 |
+
ret = ""
|
142 |
+
if self.sep_style == SeparatorStyle.MISTRAL:
|
143 |
+
ret += "<s>"
|
144 |
+
|
145 |
+
for i, (role, message) in enumerate(messages):
|
146 |
+
if i == 0:
|
147 |
+
assert message, "first message should not be none"
|
148 |
+
assert role == self.roles[0], "first message should come from user"
|
149 |
+
if message:
|
150 |
+
if type(message) is tuple:
|
151 |
+
message, _, _ = message
|
152 |
+
if i == 0: message = wrap_sys(self.system) + message
|
153 |
+
if i % 2 == 0:
|
154 |
+
message = wrap_inst(message)
|
155 |
+
ret += self.sep + message
|
156 |
+
else:
|
157 |
+
if self.sep_style == SeparatorStyle.LLAMA_2:
|
158 |
+
ret += " " + message + " " + self.sep2
|
159 |
+
else:
|
160 |
+
ret += message + self.sep2
|
161 |
+
else:
|
162 |
+
ret += ""
|
163 |
+
ret = ret.lstrip(self.sep)
|
164 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
165 |
+
seps = [self.sep, self.sep2]
|
166 |
+
ret = self.system
|
167 |
+
for i, (role, message) in enumerate(messages):
|
168 |
+
if message:
|
169 |
+
if type(message) is tuple:
|
170 |
+
message, _, _ = message
|
171 |
+
ret += message + seps[i % 2]
|
172 |
+
else:
|
173 |
+
ret += ""
|
174 |
+
else:
|
175 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
176 |
+
|
177 |
+
return ret
|
178 |
+
|
179 |
+
def append_message(self, role, message):
|
180 |
+
self.messages.append([role, message])
|
181 |
+
|
182 |
+
def get_images(self, return_pil=False):
|
183 |
+
images = []
|
184 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
185 |
+
if i % 2 == 0:
|
186 |
+
if type(msg) is tuple:
|
187 |
+
import base64
|
188 |
+
from io import BytesIO
|
189 |
+
from PIL import Image
|
190 |
+
msg, image, image_process_mode = msg
|
191 |
+
if image_process_mode == "Pad":
|
192 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
193 |
+
width, height = pil_img.size
|
194 |
+
if width == height:
|
195 |
+
return pil_img
|
196 |
+
elif width > height:
|
197 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
198 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
199 |
+
return result
|
200 |
+
else:
|
201 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
202 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
203 |
+
return result
|
204 |
+
image = expand2square(image)
|
205 |
+
elif image_process_mode in ["Default", "Crop"]:
|
206 |
+
pass
|
207 |
+
elif image_process_mode == "Resize":
|
208 |
+
image = image.resize((336, 336))
|
209 |
+
else:
|
210 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
211 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
212 |
+
aspect_ratio = max_hw / min_hw
|
213 |
+
max_len, min_len = 800, 400
|
214 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
215 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
216 |
+
W, H = image.size
|
217 |
+
if longest_edge != max(image.size):
|
218 |
+
if H > W:
|
219 |
+
H, W = longest_edge, shortest_edge
|
220 |
+
else:
|
221 |
+
H, W = shortest_edge, longest_edge
|
222 |
+
image = image.resize((W, H))
|
223 |
+
if return_pil:
|
224 |
+
images.append(image)
|
225 |
+
else:
|
226 |
+
buffered = BytesIO()
|
227 |
+
image.save(buffered, format="PNG")
|
228 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
229 |
+
images.append(img_b64_str)
|
230 |
+
return images
|
231 |
+
|
232 |
+
def to_gradio_chatbot(self):
|
233 |
+
ret = []
|
234 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
235 |
+
if i % 2 == 0:
|
236 |
+
if type(msg) is tuple:
|
237 |
+
import base64
|
238 |
+
from io import BytesIO
|
239 |
+
msg, image, image_process_mode = msg
|
240 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
241 |
+
aspect_ratio = max_hw / min_hw
|
242 |
+
max_len, min_len = 800, 400
|
243 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
244 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
245 |
+
W, H = image.size
|
246 |
+
if H > W:
|
247 |
+
H, W = longest_edge, shortest_edge
|
248 |
+
else:
|
249 |
+
H, W = shortest_edge, longest_edge
|
250 |
+
image = image.resize((W, H))
|
251 |
+
buffered = BytesIO()
|
252 |
+
image.save(buffered, format="JPEG")
|
253 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
254 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
255 |
+
msg = img_str + msg.replace('<image>', '').strip()
|
256 |
+
ret.append([msg, None])
|
257 |
+
else:
|
258 |
+
ret.append([msg, None])
|
259 |
+
else:
|
260 |
+
ret[-1][-1] = msg
|
261 |
+
return ret
|
262 |
+
|
263 |
+
def copy(self):
|
264 |
+
return Conversation(
|
265 |
+
system=self.system,
|
266 |
+
roles=self.roles,
|
267 |
+
messages=[[x, y] for x, y in self.messages],
|
268 |
+
offset=self.offset,
|
269 |
+
sep_style=self.sep_style,
|
270 |
+
sep=self.sep,
|
271 |
+
sep2=self.sep2,
|
272 |
+
version=self.version)
|
273 |
+
|
274 |
+
def dict(self):
|
275 |
+
if len(self.get_images()) > 0:
|
276 |
+
return {
|
277 |
+
"system": self.system,
|
278 |
+
"roles": self.roles,
|
279 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
280 |
+
"offset": self.offset,
|
281 |
+
"sep": self.sep,
|
282 |
+
"sep2": self.sep2,
|
283 |
+
}
|
284 |
+
return {
|
285 |
+
"system": self.system,
|
286 |
+
"roles": self.roles,
|
287 |
+
"messages": self.messages,
|
288 |
+
"offset": self.offset,
|
289 |
+
"sep": self.sep,
|
290 |
+
"sep2": self.sep2,
|
291 |
+
}
|
292 |
+
|
293 |
+
|
294 |
+
conv_vicuna_v0 = Conversation(
|
295 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
296 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
297 |
+
roles=("Human", "Assistant"),
|
298 |
+
messages=(
|
299 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
300 |
+
("Assistant",
|
301 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
302 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
303 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
304 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
305 |
+
"renewable and non-renewable energy sources:\n"
|
306 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
307 |
+
"energy sources are finite and will eventually run out.\n"
|
308 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
309 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
310 |
+
"and other negative effects.\n"
|
311 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
312 |
+
"have lower operational costs than non-renewable sources.\n"
|
313 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
314 |
+
"locations than non-renewable sources.\n"
|
315 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
316 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
317 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
318 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
319 |
+
),
|
320 |
+
offset=2,
|
321 |
+
sep_style=SeparatorStyle.SINGLE,
|
322 |
+
sep="###",
|
323 |
+
)
|
324 |
+
|
325 |
+
conv_vicuna_v1 = Conversation(
|
326 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
327 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
328 |
+
roles=("USER", "ASSISTANT"),
|
329 |
+
version="v1",
|
330 |
+
messages=(),
|
331 |
+
offset=0,
|
332 |
+
sep_style=SeparatorStyle.TWO,
|
333 |
+
sep=" ",
|
334 |
+
sep2="</s>",
|
335 |
+
)
|
336 |
+
|
337 |
+
# kentang-mit@: This conversation template is designed for SFT on VFLAN.
|
338 |
+
conv_vicuna_v1_nosys = Conversation(
|
339 |
+
system="",
|
340 |
+
roles=("USER", "ASSISTANT"),
|
341 |
+
version="v1_nosys",
|
342 |
+
messages=(),
|
343 |
+
offset=0,
|
344 |
+
sep_style=SeparatorStyle.TWO,
|
345 |
+
sep=" ",
|
346 |
+
sep2="</s>",
|
347 |
+
)
|
348 |
+
|
349 |
+
conv_llama_2 = Conversation(
|
350 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
351 |
+
|
352 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
353 |
+
roles=("USER", "ASSISTANT"),
|
354 |
+
version="llama_v2",
|
355 |
+
messages=(),
|
356 |
+
offset=0,
|
357 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
358 |
+
sep="<s>",
|
359 |
+
sep2="</s>",
|
360 |
+
)
|
361 |
+
|
362 |
+
conv_mistral = Conversation(
|
363 |
+
system="",
|
364 |
+
roles=("USER", "ASSISTANT"),
|
365 |
+
version="mistral",
|
366 |
+
messages=(),
|
367 |
+
offset=0,
|
368 |
+
sep_style=SeparatorStyle.MISTRAL,
|
369 |
+
sep="",
|
370 |
+
sep2="</s>",
|
371 |
+
)
|
372 |
+
|
373 |
+
conv_llava_llama_2 = Conversation(
|
374 |
+
system="You are a helpful language and vision assistant. "
|
375 |
+
"You are able to understand the visual content that the user provides, "
|
376 |
+
"and assist the user with a variety of tasks using natural language.",
|
377 |
+
roles=("USER", "ASSISTANT"),
|
378 |
+
version="llama_v2",
|
379 |
+
messages=(),
|
380 |
+
offset=0,
|
381 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
382 |
+
sep="<s>",
|
383 |
+
sep2="</s>",
|
384 |
+
)
|
385 |
+
|
386 |
+
conv_mpt = Conversation(
|
387 |
+
system="""<|im_start|>system
|
388 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
389 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
390 |
+
version="mpt",
|
391 |
+
messages=(),
|
392 |
+
offset=0,
|
393 |
+
sep_style=SeparatorStyle.MPT,
|
394 |
+
sep="<|im_end|>",
|
395 |
+
)
|
396 |
+
|
397 |
+
conv_plain = Conversation(
|
398 |
+
system="",
|
399 |
+
version="plain",
|
400 |
+
roles=("", ""),
|
401 |
+
messages=[],
|
402 |
+
offset=0,
|
403 |
+
sep_style=SeparatorStyle.PLAIN,
|
404 |
+
sep="\n",
|
405 |
+
)
|
406 |
+
|
407 |
+
conv_llava_v0 = Conversation(
|
408 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
409 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
410 |
+
roles=("Human", "Assistant"),
|
411 |
+
messages=(
|
412 |
+
),
|
413 |
+
offset=0,
|
414 |
+
sep_style=SeparatorStyle.SINGLE,
|
415 |
+
sep="###",
|
416 |
+
)
|
417 |
+
|
418 |
+
conv_llava_v0_mmtag = Conversation(
|
419 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
420 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
421 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
422 |
+
roles=("Human", "Assistant"),
|
423 |
+
messages=(
|
424 |
+
),
|
425 |
+
offset=0,
|
426 |
+
sep_style=SeparatorStyle.SINGLE,
|
427 |
+
sep="###",
|
428 |
+
version="v0_mmtag",
|
429 |
+
)
|
430 |
+
|
431 |
+
conv_llava_v1 = Conversation(
|
432 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
433 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
434 |
+
roles=("USER", "ASSISTANT"),
|
435 |
+
version="v1",
|
436 |
+
messages=(),
|
437 |
+
offset=0,
|
438 |
+
sep_style=SeparatorStyle.TWO,
|
439 |
+
sep=" ",
|
440 |
+
sep2="</s>",
|
441 |
+
)
|
442 |
+
|
443 |
+
|
444 |
+
|
445 |
+
conv_llava_v1_mmtag = Conversation(
|
446 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
447 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
448 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
449 |
+
roles=("USER", "ASSISTANT"),
|
450 |
+
messages=(),
|
451 |
+
offset=0,
|
452 |
+
sep_style=SeparatorStyle.TWO,
|
453 |
+
sep=" ",
|
454 |
+
sep2="</s>",
|
455 |
+
version="v1_mmtag",
|
456 |
+
)
|
457 |
+
|
458 |
+
hermes_2 = Conversation(
|
459 |
+
system='<|im_start|>system\nAnswer the questions.',
|
460 |
+
roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
|
461 |
+
sep_style=SeparatorStyle.MPT,
|
462 |
+
sep='<|im_end|>',
|
463 |
+
messages=(
|
464 |
+
),
|
465 |
+
offset=0,
|
466 |
+
version="hermes-2"
|
467 |
+
)
|
468 |
+
|
469 |
+
|
470 |
+
# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
|
471 |
+
llama_3_chat = Conversation(
|
472 |
+
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
|
473 |
+
"You are able to understand the visual content that the user provides, "
|
474 |
+
"and assist the user with a variety of tasks using natural language.",
|
475 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n",
|
476 |
+
"<|start_header_id|>system<|end_header_id|>\n\n"),
|
477 |
+
version="llama_v3",
|
478 |
+
messages=(),
|
479 |
+
offset=0,
|
480 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
481 |
+
sep="<|end_of_text|>",
|
482 |
+
)
|
483 |
+
|
484 |
+
|
485 |
+
conv_qwen = Conversation(
|
486 |
+
system="""<|im_start|>system\n\nYou are a helpful vision-language assistant.""",
|
487 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
488 |
+
version="qwen",
|
489 |
+
messages=[],
|
490 |
+
offset=0,
|
491 |
+
sep_style=SeparatorStyle.CHATML,
|
492 |
+
sep="<|im_end|>",
|
493 |
+
)
|
494 |
+
|
495 |
+
conv_qwen_2 = Conversation(
|
496 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
497 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
498 |
+
roles=("USER", "ASSISTANT"),
|
499 |
+
version="qwen_2",
|
500 |
+
messages=(),
|
501 |
+
offset=0,
|
502 |
+
sep_style=SeparatorStyle.QWEN_2,
|
503 |
+
sep=" ",
|
504 |
+
sep2="<|endoftext|>",
|
505 |
+
)
|
506 |
+
|
507 |
+
conv_gemma_instruct = Conversation(
|
508 |
+
system="",
|
509 |
+
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
|
510 |
+
version="gemma",
|
511 |
+
messages=[],
|
512 |
+
offset=0,
|
513 |
+
sep_style=SeparatorStyle.GEMMA,
|
514 |
+
sep="<end_of_turn>\n"
|
515 |
+
)
|
516 |
+
|
517 |
+
|
518 |
+
default_conversation = conv_plain
|
519 |
+
conv_templates = {
|
520 |
+
"default": conv_plain,
|
521 |
+
"hermes-2": hermes_2,
|
522 |
+
"v0": conv_vicuna_v0,
|
523 |
+
"v1": conv_vicuna_v1,
|
524 |
+
"vicuna_v1": conv_vicuna_v1,
|
525 |
+
"vicuna_v1_nosys": conv_vicuna_v1_nosys,
|
526 |
+
"llama_2": conv_llama_2,
|
527 |
+
"mistral": conv_mistral,
|
528 |
+
"plain": conv_plain,
|
529 |
+
"llava_v0": conv_llava_v0,
|
530 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
531 |
+
"llava_v1": conv_llava_v1,
|
532 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
533 |
+
"llava_llama_2": conv_llava_llama_2,
|
534 |
+
"mpt": conv_mpt,
|
535 |
+
|
536 |
+
"llama_3": llama_3_chat,
|
537 |
+
"qwen_1_5": conv_qwen,
|
538 |
+
"qwen_2": conv_qwen,
|
539 |
+
"gemma_instruct": conv_gemma_instruct,
|
540 |
+
}
|
541 |
+
|
542 |
+
|
543 |
+
if __name__ == "__main__":
|
544 |
+
print(default_conversation.get_prompt())
|
apollo/mm_utils.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# SPDX-License-Identifier: Apache-2.0
|
16 |
+
|
17 |
+
from PIL import Image
|
18 |
+
from io import BytesIO
|
19 |
+
import base64
|
20 |
+
import numpy as np
|
21 |
+
import os, math, cv2, re
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from transformers import StoppingCriteria
|
25 |
+
from apollo.constants import *
|
26 |
+
|
27 |
+
import tempfile
|
28 |
+
from io import BytesIO
|
29 |
+
from decord import VideoReader, cpu
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def read_video_cv2(video_path, all_indices):
|
34 |
+
vidcap = cv2.VideoCapture(video_path)
|
35 |
+
frames_dict = {}
|
36 |
+
max_index = max(all_indices) # Find the maximum index to avoid unnecessary reading
|
37 |
+
count = 0
|
38 |
+
success = True
|
39 |
+
while success and count <= max_index:
|
40 |
+
success, frame = vidcap.read()
|
41 |
+
if success and count in all_indices:
|
42 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
43 |
+
im_pil = Image.fromarray(img)
|
44 |
+
frames_dict[count] = im_pil
|
45 |
+
count += 1
|
46 |
+
# Now retrieve frames according to all_indices, allowing duplicates
|
47 |
+
images = [frames_dict[idx] for idx in all_indices if idx in frames_dict]
|
48 |
+
return np.stack([np.array(img) for img in images])
|
49 |
+
|
50 |
+
def read_video_decord(video_file, all_indices):
|
51 |
+
vr = VideoReader(video_file, num_threads=1, ctx=cpu(0))
|
52 |
+
return vr.get_batch(all_indices).asnumpy()
|
53 |
+
|
54 |
+
|
55 |
+
def read_video_decord_eval(video_file, all_indices):
|
56 |
+
vr = VideoReader(video_file)
|
57 |
+
return vr.get_batch(all_indices).asnumpy()
|
58 |
+
|
59 |
+
def load_frames_from_video(video_file, all_indices, video_decode_backend="decord", eval_=False):
|
60 |
+
video_ending = os.path.splitext(video_file)[1]
|
61 |
+
if video_ending in ['.gif', '.webm'] or video_decode_backend=="opencv":
|
62 |
+
buffer = read_video_cv2(video_file, all_indices)
|
63 |
+
else:
|
64 |
+
# Use decord for other video formats
|
65 |
+
if eval_:
|
66 |
+
buffer = read_video_decord_eval(video_file, all_indices)
|
67 |
+
else:
|
68 |
+
buffer = read_video_decord(video_file, all_indices)
|
69 |
+
return buffer # (T, H, W, C)
|
70 |
+
|
71 |
+
def pad_to_center_square(frames, mean_values):
|
72 |
+
"""
|
73 |
+
Pad the given frame or frames numpy array to square dimensions using the mean values as the padding color.
|
74 |
+
Handles both single frames (H, W, C) and batches of frames (N, H, W, C).
|
75 |
+
|
76 |
+
Args:
|
77 |
+
frames (np.array): The input frame array of shape (H, W, C) or (N, H, W, C).
|
78 |
+
mean_values (tuple): Mean values for each channel, typically derived from dataset normalization parameters.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
np.array: The padded frame array with square dimensions.
|
82 |
+
"""
|
83 |
+
if frames.ndim == 3: # Single frame
|
84 |
+
frames = frames[np.newaxis, :] # Add a batch dimension
|
85 |
+
elif frames.ndim != 4:
|
86 |
+
raise ValueError("Input array must be either of shape (H, W, C) or (N, H, W, C)")
|
87 |
+
|
88 |
+
N, height, width, channels = frames.shape
|
89 |
+
size = max(width, height)
|
90 |
+
background_color = np.array(mean_values, dtype=frames.dtype)
|
91 |
+
|
92 |
+
# Create a background array with the size and fill it with the mean values
|
93 |
+
padded_frames = np.full((N, size, size, channels), background_color, dtype=frames.dtype)
|
94 |
+
|
95 |
+
# Calculate padding offsets
|
96 |
+
top, left = (size - height) // 2, (size - width) // 2
|
97 |
+
|
98 |
+
# Place the original frames in the center of the square canvas
|
99 |
+
padded_frames[:, top:top + height, left:left + width, :] = frames
|
100 |
+
return padded_frames
|
101 |
+
|
102 |
+
|
103 |
+
def expand2square(pil_img, background_color):
|
104 |
+
width, height = pil_img.size
|
105 |
+
if width == height:
|
106 |
+
return pil_img
|
107 |
+
elif width > height:
|
108 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
109 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
110 |
+
# result.paste(pil_img, (0, 0))
|
111 |
+
return result
|
112 |
+
else:
|
113 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
114 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
115 |
+
# result.paste(pil_img, (0, 0))
|
116 |
+
return result
|
117 |
+
|
118 |
+
|
119 |
+
def calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=1):
|
120 |
+
sample_video_fps = frames_per_clip / clip_duration
|
121 |
+
num_clips = math.ceil((video_duration / clip_duration) * clip_sampling_ratio)
|
122 |
+
frame_step = original_fps / sample_video_fps
|
123 |
+
partition_len = total_frames // num_clips
|
124 |
+
all_indices, clip_indices, timestamps = [], [], []
|
125 |
+
if frame_step > 0.5:
|
126 |
+
frame_step = max(1, int(original_fps / sample_video_fps)) #was int/floor
|
127 |
+
clip_len = int(frames_per_clip * frame_step) #was int/floor
|
128 |
+
sample_len = min(clip_len, total_frames)
|
129 |
+
clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
|
130 |
+
for i in range(num_clips):
|
131 |
+
if partition_len > clip_len:
|
132 |
+
start_idx = (partition_len - clip_len) // 2
|
133 |
+
end_idx = start_idx + clip_len
|
134 |
+
indices = np.arange(start_idx, end_idx, frame_step)
|
135 |
+
indices = np.clip(indices, 0, partition_len-1).astype(np.int64)
|
136 |
+
indices = indices+ i * partition_len
|
137 |
+
|
138 |
+
else:
|
139 |
+
|
140 |
+
indices = np.arange(0, sample_len, frame_step)
|
141 |
+
if len(indices) < frames_per_clip:
|
142 |
+
padding = np.full(frames_per_clip - len(indices), sample_len)
|
143 |
+
indices = np.concatenate((indices, padding))
|
144 |
+
|
145 |
+
indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
|
146 |
+
indices = indices + i * clip_step
|
147 |
+
|
148 |
+
clip_indices.append(indices)
|
149 |
+
all_indices.extend(list(indices))
|
150 |
+
|
151 |
+
# Calculate timestamps
|
152 |
+
start_time = (indices[0] / original_fps)
|
153 |
+
end_time = (indices[-1] / original_fps)
|
154 |
+
timestamps.append((start_time, end_time))
|
155 |
+
|
156 |
+
else:
|
157 |
+
## original video FPS too low, we need to sample the same frame multiple times.
|
158 |
+
## Generally should not happen.
|
159 |
+
# Calculate the number of times each frame should be sampled
|
160 |
+
num_sample = int(np.ceil(1 / frame_step))
|
161 |
+
|
162 |
+
# Compute the effective clip length considering the frame step
|
163 |
+
clip_len = int(frames_per_clip * frame_step)
|
164 |
+
|
165 |
+
# Create an expanded list of indices with each frame repeated num_sample times
|
166 |
+
indices = np.repeat(np.arange(clip_len), num_sample)
|
167 |
+
|
168 |
+
# Ensure the clip length does not exceed the total number of frames
|
169 |
+
clip_len = min(clip_len, len(indices))
|
170 |
+
clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
|
171 |
+
|
172 |
+
sample_len = min(clip_len, total_frames)
|
173 |
+
if len(indices) < frames_per_clip:
|
174 |
+
padding = np.full(frames_per_clip - len(indices), sample_len)
|
175 |
+
indices = np.concatenate((indices, padding))
|
176 |
+
|
177 |
+
# Distribute the indices into clips
|
178 |
+
for i in range(num_clips):
|
179 |
+
current_clip_indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
|
180 |
+
current_clip_indices = current_clip_indices + i * clip_step
|
181 |
+
|
182 |
+
# Append the current clip indices to the list of all clips
|
183 |
+
clip_indices.append(current_clip_indices)
|
184 |
+
all_indices.extend(current_clip_indices)
|
185 |
+
|
186 |
+
# Calculate timestamps
|
187 |
+
start_time = (current_clip_indices[0] / original_fps)
|
188 |
+
end_time = (current_clip_indices[-1] / original_fps)
|
189 |
+
timestamps.append((start_time, end_time))
|
190 |
+
|
191 |
+
return clip_indices, all_indices, timestamps
|
192 |
+
|
193 |
+
def calculate_sample_indices_uniform(frames_per_clip, total_frames, uniform_frame_count, original_fps):
|
194 |
+
|
195 |
+
# Generate indices
|
196 |
+
if total_frames >= N:
|
197 |
+
# Sample N frames uniformly without replacement
|
198 |
+
indices = np.linspace(0, total_frames - 1, N, dtype=int)
|
199 |
+
else:
|
200 |
+
# Not enough frames; repeat frames to reach N frames
|
201 |
+
repeats = math.ceil(N / total_frames)
|
202 |
+
base_indices = np.arange(total_frames)
|
203 |
+
indices = np.tile(base_indices, repeats)[:N]
|
204 |
+
|
205 |
+
# Split indices into clips
|
206 |
+
clip_indices = [
|
207 |
+
indices[i * frames_per_clip: (i + 1) * frames_per_clip]
|
208 |
+
for i in range(num_clips)
|
209 |
+
]
|
210 |
+
|
211 |
+
# Calculate timestamps for each clip
|
212 |
+
timestamps = []
|
213 |
+
for clip in clip_indices:
|
214 |
+
start_time = clip[0] / original_fps
|
215 |
+
end_time = clip[-1] / original_fps
|
216 |
+
timestamps.append((start_time, end_time))
|
217 |
+
|
218 |
+
all_indices = indices.tolist()
|
219 |
+
return clip_indices, all_indices, timestamps
|
220 |
+
|
221 |
+
|
222 |
+
def get_video_details(fname):
|
223 |
+
""" Load video content using Decord """
|
224 |
+
assert os.path.exists(fname), f'video path not found {fname}'
|
225 |
+
_fsize = os.path.getsize(fname)
|
226 |
+
assert _fsize >= 1 * 1024, f"video too short {fname}"
|
227 |
+
vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
|
228 |
+
# Get the total number of frames and the original fps of the video
|
229 |
+
total_frames = len(vr)
|
230 |
+
original_fps = vr.get_avg_fps()
|
231 |
+
video_duration = total_frames / original_fps
|
232 |
+
return total_frames, original_fps, video_duration
|
233 |
+
|
234 |
+
|
235 |
+
def get_video_details_cv2(fname):
|
236 |
+
"""
|
237 |
+
Load video content using OpenCV (cv2) and retrieve video details.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
fname (str): Path to the video file.
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
tuple: A tuple containing:
|
244 |
+
- total_frames (int): Total number of frames in the video.
|
245 |
+
- original_fps (float): Frames per second of the video.
|
246 |
+
- video_duration (float): Duration of the video in seconds.
|
247 |
+
|
248 |
+
Raises:
|
249 |
+
AssertionError: If the file does not exist or is too short.
|
250 |
+
ValueError: If the video cannot be opened or FPS is zero.
|
251 |
+
"""
|
252 |
+
# Check if the file exists
|
253 |
+
assert os.path.exists(fname), f'Video path not found: {fname}'
|
254 |
+
|
255 |
+
# Check if the file size is at least 1 KB
|
256 |
+
_fsize = os.path.getsize(fname)
|
257 |
+
assert _fsize >= 1 * 1024, f"Video too short: {fname}"
|
258 |
+
|
259 |
+
# Open the video file
|
260 |
+
cap = cv2.VideoCapture(fname)
|
261 |
+
if not cap.isOpened():
|
262 |
+
raise ValueError(f"Failed to open video file: {fname}")
|
263 |
+
|
264 |
+
# Retrieve the total number of frames
|
265 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
266 |
+
|
267 |
+
# Retrieve the frames per second (FPS)
|
268 |
+
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
269 |
+
if original_fps == 0:
|
270 |
+
cap.release()
|
271 |
+
raise ValueError(f"Failed to get FPS for video file: {fname}")
|
272 |
+
|
273 |
+
# Calculate the video duration in seconds
|
274 |
+
video_duration = total_frames / original_fps
|
275 |
+
|
276 |
+
# Release the video capture object
|
277 |
+
cap.release()
|
278 |
+
|
279 |
+
return total_frames, original_fps, video_duration
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
def split_into_clips(video, frames_per_clip):
|
284 |
+
""" Split video into a list of clips """
|
285 |
+
fpc = frames_per_clip
|
286 |
+
nc = len(video) // frames_per_clip
|
287 |
+
return [video[i*fpc:(i+1)*fpc] for i in range(nc)]
|
288 |
+
|
289 |
+
def process_image(vision_processors, frames_per_clip, image):
|
290 |
+
mm_data = []
|
291 |
+
for vision_processor in vision_processors:
|
292 |
+
tmp = expand2square(image, tuple(int(x * 255) for x in vision_processor.image_mean))
|
293 |
+
tmp = np.expand_dims(np.asarray(tmp), 0)
|
294 |
+
tmp = vision_processor.preprocess(tmp, return_tensors='pt')['pixel_values'][0].unsqueeze(0)
|
295 |
+
if len(tmp.shape)==4:
|
296 |
+
## image, need B, T, C, W, H
|
297 |
+
tmp = tmp.unsqueeze(1)
|
298 |
+
tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
|
299 |
+
else:
|
300 |
+
## video, need B, C, T, W, H
|
301 |
+
if tmp.shape[1]==1:
|
302 |
+
tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
|
303 |
+
else:
|
304 |
+
tmp = tmp.repeat_interleave(frames_per_clip, dim=2)
|
305 |
+
|
306 |
+
mm_data.append(tmp)
|
307 |
+
return mm_data
|
308 |
+
|
309 |
+
def process_video(vision_processors, frames_per_clip, buffer):
|
310 |
+
mm_data=[]
|
311 |
+
for vision_processor in vision_processors:
|
312 |
+
centered_buffer = pad_to_center_square(buffer, tuple(int(x * 255) for x in vision_processor.image_mean))
|
313 |
+
processed_clips = []
|
314 |
+
for clip in split_into_clips(centered_buffer, frames_per_clip):
|
315 |
+
clip = vision_processor.preprocess(clip, return_tensors='pt')['pixel_values']
|
316 |
+
if type(clip) is list:
|
317 |
+
assert len(clip)==1, "LazyVideoDataset: error, vision processor returned clip that is list of len>1 ."
|
318 |
+
clip = clip[0]
|
319 |
+
processed_clips.append(clip)
|
320 |
+
mm_data.append(torch.stack(processed_clips))
|
321 |
+
return mm_data
|
322 |
+
|
323 |
+
def load_video(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False):
|
324 |
+
total_frames, original_fps, video_duration = get_video_details(video_file)
|
325 |
+
_, all_indices, timestamps = calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
|
326 |
+
buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
|
327 |
+
mm_data = process_video(vision_processors, frames_per_clip, buffer)
|
328 |
+
return mm_data, timestamps
|
329 |
+
|
330 |
+
def load_video_uniform(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False, uniform_sampling=8):
|
331 |
+
total_frames, original_fps, video_duration = get_video_details(video_file)
|
332 |
+
all_indices = np.linspace(0, total_frames-1, uniform_sampling, dtype=int)
|
333 |
+
print('using uniform frame sampled, sampled: ', len(all_indices), ' frames')
|
334 |
+
buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
|
335 |
+
mm_data = process_video(vision_processors, frames_per_clip, buffer)
|
336 |
+
return mm_data, []
|
337 |
+
|
338 |
+
|
339 |
+
|
340 |
+
class ApolloMMLoader:
|
341 |
+
def __init__(self, vision_processors, clip_duration, frames_per_clip, num_repeat_token, device, model_max_length = 32768, clip_sampling_ratio=1, video_decode_backend="decord"):
|
342 |
+
self.vision_processors=vision_processors
|
343 |
+
self.clip_duration=clip_duration
|
344 |
+
self.device=device
|
345 |
+
self.frames_per_clip=frames_per_clip
|
346 |
+
self.num_repeat_token = num_repeat_token
|
347 |
+
self.clip_sampling_ratio=clip_sampling_ratio
|
348 |
+
self.model_max_length=model_max_length
|
349 |
+
self.video_decode_backend=video_decode_backend
|
350 |
+
self.vidprompt = lambda num_clips, video_duration : f"You are provided the following series of {num2words(num_clips)}, {self.clip_duration} second clips from a {datetime.timedelta(seconds=video_duration)} [H:MM:SS] video.\n"
|
351 |
+
|
352 |
+
def load_video(self, video_file):
|
353 |
+
total_frames, original_fps, video_duration = get_video_details(video_file)
|
354 |
+
clip_sampling_ratio = min(1, (self.model_max_length * self.clip_sampling_ratio) / (video_duration * self.num_repeat_token / self.clip_duration))
|
355 |
+
|
356 |
+
_, all_indices, timestamps = calculate_sample_indices(self.clip_duration, self.frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
|
357 |
+
video, timestamps = load_video(video_file, self.vision_processors, self.clip_duration, self.frames_per_clip, clip_sampling_ratio=clip_sampling_ratio, eval_=True)
|
358 |
+
|
359 |
+
num_clips = len(video[0])
|
360 |
+
num_tokens = num_clips * self.num_repeat_token
|
361 |
+
video = [v.to(device=self.device, dtype=torch.bfloat16) for v in video]
|
362 |
+
replace_string = self.vidprompt(num_clips, video_duration)
|
363 |
+
|
364 |
+
temporal_prompt = [f"{round(clip[0], 1)}-{round(clip[1], 1)} seconds: {X_TOKEN['video'] * self.num_repeat_token}" for clip in timestamps]
|
365 |
+
temporal_prompt = ',\n'.join(temporal_prompt)
|
366 |
+
replace_string = replace_string + temporal_prompt
|
367 |
+
|
368 |
+
return video, replace_string
|
369 |
+
|
370 |
+
def load_image(self, image_file):
|
371 |
+
print('implement image loading')
|
372 |
+
return None
|
373 |
+
|
374 |
+
|
375 |
+
def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
|
376 |
+
import cv2
|
377 |
+
|
378 |
+
if fps == None or frame_count == None:
|
379 |
+
# if one of fps or frame_count is None, still recompute
|
380 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
381 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
382 |
+
if fps == 0 or frame_count == 0:
|
383 |
+
print("Video file not found. return empty images.")
|
384 |
+
return [
|
385 |
+
Image.new("RGB", (720, 720)),
|
386 |
+
] * num_frames
|
387 |
+
|
388 |
+
duration = frame_count / fps
|
389 |
+
frame_interval = frame_count // num_frames
|
390 |
+
if frame_interval == 0 and frame_count <= 1:
|
391 |
+
print("frame_interval is equal to 0. return empty image.")
|
392 |
+
return [
|
393 |
+
Image.new("RGB", (720, 720)),
|
394 |
+
] * num_frames
|
395 |
+
# print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
|
396 |
+
|
397 |
+
images = []
|
398 |
+
count = 0
|
399 |
+
success = True
|
400 |
+
frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)
|
401 |
+
|
402 |
+
while success:
|
403 |
+
# print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
|
404 |
+
if frame_count >= num_frames:
|
405 |
+
success, frame = vidcap.read()
|
406 |
+
if count in frame_indices:
|
407 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
408 |
+
im_pil = Image.fromarray(img)
|
409 |
+
images.append(im_pil)
|
410 |
+
if len(images) >= num_frames:
|
411 |
+
return images
|
412 |
+
count += 1
|
413 |
+
else:
|
414 |
+
# Left padding frames if the video is not long enough
|
415 |
+
success, frame = vidcap.read()
|
416 |
+
if success:
|
417 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
418 |
+
im_pil = Image.fromarray(img)
|
419 |
+
images.append(im_pil)
|
420 |
+
count += 1
|
421 |
+
elif count >= 1:
|
422 |
+
width, height = images[-1].size
|
423 |
+
images = [Image.new("RGB", (width, height))] * (num_frames - len(images)) + images
|
424 |
+
print("padding frames:", (num_frames - len(images)))
|
425 |
+
return images
|
426 |
+
else:
|
427 |
+
break
|
428 |
+
raise ValueError("Did not find enough frames in the video. return empty image.")
|
429 |
+
|
430 |
+
|
431 |
+
def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
|
432 |
+
"""
|
433 |
+
Extract frames from a video using OpenCV.
|
434 |
+
|
435 |
+
Args:
|
436 |
+
vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
|
437 |
+
frames (int): Number of frames to extract from the video.
|
438 |
+
|
439 |
+
Returns:
|
440 |
+
list: List of PIL Images extracted from the video.
|
441 |
+
|
442 |
+
Raises:
|
443 |
+
NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
|
444 |
+
"""
|
445 |
+
import cv2
|
446 |
+
|
447 |
+
if isinstance(vpath_or_bytesio, str):
|
448 |
+
vidcap = cv2.VideoCapture(vpath_or_bytesio)
|
449 |
+
return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
|
450 |
+
elif isinstance(vpath_or_bytesio, (BytesIO,)):
|
451 |
+
# assuming mp4
|
452 |
+
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
|
453 |
+
temp_video.write(vpath_or_bytesio.read())
|
454 |
+
temp_video_name = temp_video.name
|
455 |
+
vidcap = cv2.VideoCapture(temp_video_name)
|
456 |
+
return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
|
457 |
+
else:
|
458 |
+
raise NotImplementedError(type(vpath_or_bytesio))
|
459 |
+
|
460 |
+
|
461 |
+
def load_image_from_base64(image):
|
462 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
463 |
+
|
464 |
+
|
465 |
+
def expand2square(pil_img, background_color):
|
466 |
+
"""
|
467 |
+
Expand the given PIL image to a square shape by adding padding.
|
468 |
+
|
469 |
+
Parameters:
|
470 |
+
- pil_img: The PIL image to be expanded.
|
471 |
+
- background_color: The color of the padding to be added.
|
472 |
+
|
473 |
+
Returns:
|
474 |
+
- The expanded PIL image.
|
475 |
+
|
476 |
+
If the image is already square, it is returned as is.
|
477 |
+
If the image is wider than it is tall, padding is added to the top and bottom.
|
478 |
+
If the image is taller than it is wide, padding is added to the left and right.
|
479 |
+
"""
|
480 |
+
width, height = pil_img.size
|
481 |
+
if pil_img.mode == 'L':
|
482 |
+
background_color = background_color[0]
|
483 |
+
if width == height:
|
484 |
+
return pil_img
|
485 |
+
elif width > height:
|
486 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
487 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
488 |
+
return result
|
489 |
+
else:
|
490 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
491 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
492 |
+
return result
|
493 |
+
|
494 |
+
|
495 |
+
|
496 |
+
def process_images(images, image_processor, model_cfg):
|
497 |
+
|
498 |
+
model_cfg.image_processor = image_processor
|
499 |
+
new_images = [process_image(image, model_cfg, None) for image in images]
|
500 |
+
|
501 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
502 |
+
new_images = torch.stack(new_images, dim=0)
|
503 |
+
return new_images
|
504 |
+
|
505 |
+
|
506 |
+
|
507 |
+
|
508 |
+
def tokenizer_mm_token(prompt, tokenizer, return_tensors=None):
|
509 |
+
tokens_regex = re.compile('|'.join(re.escape(token) for token in X_TOKEN.values()))
|
510 |
+
input_ids, last_pos, start_id = [], 0, 0
|
511 |
+
for match in tokens_regex.finditer(prompt):
|
512 |
+
if match.start() > last_pos:
|
513 |
+
input_ids.extend(tokenizer(prompt[last_pos:match.start()]).input_ids)
|
514 |
+
elif match.start() == 0:
|
515 |
+
input_ids = tokenizer('').input_ids
|
516 |
+
start_id = 1
|
517 |
+
input_ids.append(X_TOKEN_INDEX)
|
518 |
+
last_pos = match.end()
|
519 |
+
if last_pos < len(prompt):
|
520 |
+
input_ids.extend(tokenizer(prompt[last_pos:]).input_ids[start_id:])
|
521 |
+
return torch.tensor(input_ids, dtype=torch.long) if return_tensors == 'pt' else input_ids
|
522 |
+
|
523 |
+
|
524 |
+
def is_gemma_tokenizer(tokenizer):
|
525 |
+
return "gemma" in tokenizer.__class__.__name__.lower()
|
526 |
+
|
527 |
+
|
528 |
+
def get_model_name_from_path(model_path):
|
529 |
+
model_path = model_path.strip("/")
|
530 |
+
model_paths = model_path.split("/")
|
531 |
+
if model_paths[-1].startswith("checkpoint-"):
|
532 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
533 |
+
else:
|
534 |
+
return model_paths[-1]
|
535 |
+
|
536 |
+
|
537 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
538 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
539 |
+
self.keywords = keywords
|
540 |
+
self.keyword_ids = []
|
541 |
+
self.max_keyword_len = 0
|
542 |
+
for keyword in keywords:
|
543 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
544 |
+
if (
|
545 |
+
len(cur_keyword_ids) > 1
|
546 |
+
and cur_keyword_ids[0] == tokenizer.bos_token_id
|
547 |
+
):
|
548 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
549 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
550 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
551 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
552 |
+
self.tokenizer = tokenizer
|
553 |
+
self.start_len = input_ids.shape[1]
|
554 |
+
|
555 |
+
def call_for_batch(
|
556 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
557 |
+
) -> bool:
|
558 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
559 |
+
self.keyword_ids = [
|
560 |
+
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
|
561 |
+
]
|
562 |
+
for keyword_id in self.keyword_ids:
|
563 |
+
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
|
564 |
+
return True
|
565 |
+
outputs = self.tokenizer.batch_decode(
|
566 |
+
output_ids[:, -offset:], skip_special_tokens=True
|
567 |
+
)[0]
|
568 |
+
for keyword in self.keywords:
|
569 |
+
if keyword in outputs:
|
570 |
+
return True
|
571 |
+
return False
|
572 |
+
|
573 |
+
def __call__(
|
574 |
+
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
575 |
+
) -> bool:
|
576 |
+
outputs = []
|
577 |
+
for i in range(output_ids.shape[0]):
|
578 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
579 |
+
return all(outputs)
|