manysuch-cases committed on
Commit f23293d · verified · 1 parent: 29c2064

Upload 3 files

Files changed (3)
  1. apollo/constants.py +31 -0
  2. apollo/conversation.py +544 -0
  3. apollo/mm_utils.py +579 -0
apollo/constants.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+
19
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
20
+ WORKER_HEART_BEAT_INTERVAL = 15
21
+
22
+ LOGDIR = "."
23
+
24
+
25
+ # Model Constants
26
+ IGNORE_INDEX = -100
27
+ X_TOKEN_INDEX = -200
28
+ X_TOKEN = {'image': "<|image_token|>", 'video': "<|video_token|>"}
29
+ X_PATCH_TOKEN = {'image': "<|image_patch|>", 'video': "<|video_patch|>"}
30
+ X_START_TOKEN = {'image': "<|image_start|>", 'video': "<|video_start|>"}
31
+ X_END_TOKEN = {'image': "<|image_end|>", 'video': "<|video_end|>"}
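
For reference, a minimal sketch (not part of this commit) of how these placeholder tokens are consumed: X_TOKEN['video'] is repeated once per visual token and spliced into the text prompt, mirroring the temporal prompt built in apollo/mm_utils.py below. The clip timestamps and token budget here are made-up values.

    from apollo.constants import X_TOKEN, X_TOKEN_INDEX

    num_repeat_token = 4                      # hypothetical tokens-per-clip budget
    timestamps = [(0.0, 2.0), (2.0, 4.0)]     # hypothetical clip boundaries (seconds)

    # Same formatting as ApolloMMLoader.load_video: one line per clip, each line
    # carrying num_repeat_token copies of the video placeholder token.
    temporal_prompt = ',\n'.join(
        f"{round(start, 1)}-{round(end, 1)} seconds: {X_TOKEN['video'] * num_repeat_token}"
        for start, end in timestamps
    )
    print(temporal_prompt)
    # During tokenization, each <|video_token|> is later mapped to X_TOKEN_INDEX (-200).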
apollo/conversation.py ADDED
@@ -0,0 +1,544 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+
19
+ import dataclasses
20
+ from enum import auto, Enum
21
+ from typing import List, Tuple
22
+
23
+
24
+ class SeparatorStyle(Enum):
25
+ """Different separator style."""
26
+ SINGLE = auto()
27
+ TWO = auto()
28
+ MPT = auto()
29
+ PLAIN = auto()
30
+ LLAMA_2 = auto()
31
+ LLAMA_3 = auto()
32
+ MISTRAL = auto()
33
+ CHATML = auto()
34
+ QWEN = auto()
35
+ QWEN_2 = auto()
36
+ GEMMA = auto()
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class Conversation:
41
+ """A class that keeps all conversation history."""
42
+ system: str
43
+ roles: List[str]
44
+ messages: List[List[str]]
45
+ offset: int
46
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
47
+ sep: str = "###"
48
+ sep2: str = None
49
+ version: str = "Unknown"
50
+
51
+ skip_next: bool = False
52
+
53
+ def get_prompt(self):
54
+ messages = self.messages
55
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
56
+ messages = self.messages.copy()
57
+ init_role, init_msg = messages[0].copy()
58
+ init_msg = init_msg[0].replace("<image>", "").strip()
59
+ if 'mmtag' in self.version:
60
+ messages[0] = (init_role, init_msg)
61
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
62
+ messages.insert(1, (self.roles[1], "Received."))
63
+ else:
64
+ messages[0] = (init_role, "<image>\n" + init_msg)
65
+
66
+ if self.sep_style == SeparatorStyle.SINGLE:
67
+ ret = self.system + self.sep
68
+ for role, message in messages:
69
+ if message:
70
+ if type(message) is tuple:
71
+ message, _, _ = message
72
+ ret += role + ": " + message + self.sep
73
+ else:
74
+ ret += role + ":"
75
+ elif self.sep_style == SeparatorStyle.TWO:
76
+ seps = [self.sep, self.sep2]
77
+ ret = self.system + seps[0]
78
+ for i, (role, message) in enumerate(messages):
79
+ if message:
80
+ if type(message) is tuple:
81
+ message, _, _ = message
82
+ ret += role + ": " + message + seps[i % 2]
83
+ else:
84
+ ret += role + ":"
85
+ elif self.sep_style == SeparatorStyle.QWEN_2:
86
+ seps = [self.sep, self.sep2]
87
+ ret = self.system + seps[0]
88
+ for i, (role, message) in enumerate(messages):
89
+ if message:
90
+ if type(message) is tuple:
91
+ message, _, _ = message
92
+ ret += role + ": " + message + seps[i % 2]
93
+ else:
94
+ ret += role + ":"
95
+ elif self.sep_style == SeparatorStyle.CHATML:
96
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
97
+ for role, message in messages:
98
+ if message:
99
+ if type(message) is tuple:
100
+ #TODO! NEED to add MM support!
101
+ message, images = message
102
+ message = "<image>" * len(images) + message
103
+ ret += role + "\n" + message + self.sep + "\n"
104
+ else:
105
+ ret += role + "\n"
106
+ return ret
107
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
108
+ ret = self.system + self.sep
109
+ for role, message in messages:
110
+ if message:
111
+ if type(message) is tuple:
112
+ message = message[0]
113
+ ret += role + message + self.sep
114
+ else:
115
+ ret += role
116
+ elif self.sep_style == SeparatorStyle.MPT:
117
+ ret = self.system + self.sep
118
+ for role, message in messages:
119
+ if message:
120
+ if type(message) is tuple:
121
+ message, _, _ = message
122
+ ret += role + message + self.sep
123
+ else:
124
+ ret += role
125
+ elif self.sep_style == SeparatorStyle.GEMMA:
126
+ ret = ""
127
+ for i, (role, message) in enumerate(messages):
128
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
129
+ if message:
130
+ if type(message) is tuple:
131
+ message, _, _ = message
132
+ ret += role + message + self.sep
133
+ else:
134
+ ret += role
135
+ elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
136
+ if self.sep_style == SeparatorStyle.LLAMA_2:
137
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
138
+ else:
139
+ wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
140
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
141
+ ret = ""
142
+ if self.sep_style == SeparatorStyle.MISTRAL:
143
+ ret += "<s>"
144
+
145
+ for i, (role, message) in enumerate(messages):
146
+ if i == 0:
147
+ assert message, "first message should not be none"
148
+ assert role == self.roles[0], "first message should come from user"
149
+ if message:
150
+ if type(message) is tuple:
151
+ message, _, _ = message
152
+ if i == 0: message = wrap_sys(self.system) + message
153
+ if i % 2 == 0:
154
+ message = wrap_inst(message)
155
+ ret += self.sep + message
156
+ else:
157
+ if self.sep_style == SeparatorStyle.LLAMA_2:
158
+ ret += " " + message + " " + self.sep2
159
+ else:
160
+ ret += message + self.sep2
161
+ else:
162
+ ret += ""
163
+ ret = ret.lstrip(self.sep)
164
+ elif self.sep_style == SeparatorStyle.PLAIN:
165
+ seps = [self.sep, self.sep2]
166
+ ret = self.system
167
+ for i, (role, message) in enumerate(messages):
168
+ if message:
169
+ if type(message) is tuple:
170
+ message, _, _ = message
171
+ ret += message + seps[i % 2]
172
+ else:
173
+ ret += ""
174
+ else:
175
+ raise ValueError(f"Invalid style: {self.sep_style}")
176
+
177
+ return ret
178
+
179
+ def append_message(self, role, message):
180
+ self.messages.append([role, message])
181
+
182
+ def get_images(self, return_pil=False):
183
+ images = []
184
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
185
+ if i % 2 == 0:
186
+ if type(msg) is tuple:
187
+ import base64
188
+ from io import BytesIO
189
+ from PIL import Image
190
+ msg, image, image_process_mode = msg
191
+ if image_process_mode == "Pad":
192
+ def expand2square(pil_img, background_color=(122, 116, 104)):
193
+ width, height = pil_img.size
194
+ if width == height:
195
+ return pil_img
196
+ elif width > height:
197
+ result = Image.new(pil_img.mode, (width, width), background_color)
198
+ result.paste(pil_img, (0, (width - height) // 2))
199
+ return result
200
+ else:
201
+ result = Image.new(pil_img.mode, (height, height), background_color)
202
+ result.paste(pil_img, ((height - width) // 2, 0))
203
+ return result
204
+ image = expand2square(image)
205
+ elif image_process_mode in ["Default", "Crop"]:
206
+ pass
207
+ elif image_process_mode == "Resize":
208
+ image = image.resize((336, 336))
209
+ else:
210
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
211
+ max_hw, min_hw = max(image.size), min(image.size)
212
+ aspect_ratio = max_hw / min_hw
213
+ max_len, min_len = 800, 400
214
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
215
+ longest_edge = int(shortest_edge * aspect_ratio)
216
+ W, H = image.size
217
+ if longest_edge != max(image.size):
218
+ if H > W:
219
+ H, W = longest_edge, shortest_edge
220
+ else:
221
+ H, W = shortest_edge, longest_edge
222
+ image = image.resize((W, H))
223
+ if return_pil:
224
+ images.append(image)
225
+ else:
226
+ buffered = BytesIO()
227
+ image.save(buffered, format="PNG")
228
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
229
+ images.append(img_b64_str)
230
+ return images
231
+
232
+ def to_gradio_chatbot(self):
233
+ ret = []
234
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
235
+ if i % 2 == 0:
236
+ if type(msg) is tuple:
237
+ import base64
238
+ from io import BytesIO
239
+ msg, image, image_process_mode = msg
240
+ max_hw, min_hw = max(image.size), min(image.size)
241
+ aspect_ratio = max_hw / min_hw
242
+ max_len, min_len = 800, 400
243
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
244
+ longest_edge = int(shortest_edge * aspect_ratio)
245
+ W, H = image.size
246
+ if H > W:
247
+ H, W = longest_edge, shortest_edge
248
+ else:
249
+ H, W = shortest_edge, longest_edge
250
+ image = image.resize((W, H))
251
+ buffered = BytesIO()
252
+ image.save(buffered, format="JPEG")
253
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
254
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
255
+ msg = img_str + msg.replace('<image>', '').strip()
256
+ ret.append([msg, None])
257
+ else:
258
+ ret.append([msg, None])
259
+ else:
260
+ ret[-1][-1] = msg
261
+ return ret
262
+
263
+ def copy(self):
264
+ return Conversation(
265
+ system=self.system,
266
+ roles=self.roles,
267
+ messages=[[x, y] for x, y in self.messages],
268
+ offset=self.offset,
269
+ sep_style=self.sep_style,
270
+ sep=self.sep,
271
+ sep2=self.sep2,
272
+ version=self.version)
273
+
274
+ def dict(self):
275
+ if len(self.get_images()) > 0:
276
+ return {
277
+ "system": self.system,
278
+ "roles": self.roles,
279
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
280
+ "offset": self.offset,
281
+ "sep": self.sep,
282
+ "sep2": self.sep2,
283
+ }
284
+ return {
285
+ "system": self.system,
286
+ "roles": self.roles,
287
+ "messages": self.messages,
288
+ "offset": self.offset,
289
+ "sep": self.sep,
290
+ "sep2": self.sep2,
291
+ }
292
+
293
+
294
+ conv_vicuna_v0 = Conversation(
295
+ system="A chat between a curious human and an artificial intelligence assistant. "
296
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
297
+ roles=("Human", "Assistant"),
298
+ messages=(
299
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
300
+ ("Assistant",
301
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
302
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
303
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
304
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
305
+ "renewable and non-renewable energy sources:\n"
306
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
307
+ "energy sources are finite and will eventually run out.\n"
308
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
309
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
310
+ "and other negative effects.\n"
311
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
312
+ "have lower operational costs than non-renewable sources.\n"
313
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
314
+ "locations than non-renewable sources.\n"
315
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
316
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
317
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
318
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
319
+ ),
320
+ offset=2,
321
+ sep_style=SeparatorStyle.SINGLE,
322
+ sep="###",
323
+ )
324
+
325
+ conv_vicuna_v1 = Conversation(
326
+ system="A chat between a curious user and an artificial intelligence assistant. "
327
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
328
+ roles=("USER", "ASSISTANT"),
329
+ version="v1",
330
+ messages=(),
331
+ offset=0,
332
+ sep_style=SeparatorStyle.TWO,
333
+ sep=" ",
334
+ sep2="</s>",
335
+ )
336
+
337
+ # kentang-mit@: This conversation template is designed for SFT on VFLAN.
338
+ conv_vicuna_v1_nosys = Conversation(
339
+ system="",
340
+ roles=("USER", "ASSISTANT"),
341
+ version="v1_nosys",
342
+ messages=(),
343
+ offset=0,
344
+ sep_style=SeparatorStyle.TWO,
345
+ sep=" ",
346
+ sep2="</s>",
347
+ )
348
+
349
+ conv_llama_2 = Conversation(
350
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
351
+
352
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
353
+ roles=("USER", "ASSISTANT"),
354
+ version="llama_v2",
355
+ messages=(),
356
+ offset=0,
357
+ sep_style=SeparatorStyle.LLAMA_2,
358
+ sep="<s>",
359
+ sep2="</s>",
360
+ )
361
+
362
+ conv_mistral = Conversation(
363
+ system="",
364
+ roles=("USER", "ASSISTANT"),
365
+ version="mistral",
366
+ messages=(),
367
+ offset=0,
368
+ sep_style=SeparatorStyle.MISTRAL,
369
+ sep="",
370
+ sep2="</s>",
371
+ )
372
+
373
+ conv_llava_llama_2 = Conversation(
374
+ system="You are a helpful language and vision assistant. "
375
+ "You are able to understand the visual content that the user provides, "
376
+ "and assist the user with a variety of tasks using natural language.",
377
+ roles=("USER", "ASSISTANT"),
378
+ version="llama_v2",
379
+ messages=(),
380
+ offset=0,
381
+ sep_style=SeparatorStyle.LLAMA_2,
382
+ sep="<s>",
383
+ sep2="</s>",
384
+ )
385
+
386
+ conv_mpt = Conversation(
387
+ system="""<|im_start|>system
388
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
389
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
390
+ version="mpt",
391
+ messages=(),
392
+ offset=0,
393
+ sep_style=SeparatorStyle.MPT,
394
+ sep="<|im_end|>",
395
+ )
396
+
397
+ conv_plain = Conversation(
398
+ system="",
399
+ version="plain",
400
+ roles=("", ""),
401
+ messages=[],
402
+ offset=0,
403
+ sep_style=SeparatorStyle.PLAIN,
404
+ sep="\n",
405
+ )
406
+
407
+ conv_llava_v0 = Conversation(
408
+ system="A chat between a curious human and an artificial intelligence assistant. "
409
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
410
+ roles=("Human", "Assistant"),
411
+ messages=(
412
+ ),
413
+ offset=0,
414
+ sep_style=SeparatorStyle.SINGLE,
415
+ sep="###",
416
+ )
417
+
418
+ conv_llava_v0_mmtag = Conversation(
419
+ system="A chat between a curious user and an artificial intelligence assistant. "
420
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
421
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
422
+ roles=("Human", "Assistant"),
423
+ messages=(
424
+ ),
425
+ offset=0,
426
+ sep_style=SeparatorStyle.SINGLE,
427
+ sep="###",
428
+ version="v0_mmtag",
429
+ )
430
+
431
+ conv_llava_v1 = Conversation(
432
+ system="A chat between a curious human and an artificial intelligence assistant. "
433
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
434
+ roles=("USER", "ASSISTANT"),
435
+ version="v1",
436
+ messages=(),
437
+ offset=0,
438
+ sep_style=SeparatorStyle.TWO,
439
+ sep=" ",
440
+ sep2="</s>",
441
+ )
442
+
443
+
444
+
445
+ conv_llava_v1_mmtag = Conversation(
446
+ system="A chat between a curious user and an artificial intelligence assistant. "
447
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
448
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
449
+ roles=("USER", "ASSISTANT"),
450
+ messages=(),
451
+ offset=0,
452
+ sep_style=SeparatorStyle.TWO,
453
+ sep=" ",
454
+ sep2="</s>",
455
+ version="v1_mmtag",
456
+ )
457
+
458
+ hermes_2 = Conversation(
459
+ system='<|im_start|>system\nAnswer the questions.',
460
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
461
+ sep_style=SeparatorStyle.MPT,
462
+ sep='<|im_end|>',
463
+ messages=(
464
+ ),
465
+ offset=0,
466
+ version="hermes-2"
467
+ )
468
+
469
+
470
+ # Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
471
+ llama_3_chat = Conversation(
472
+ system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
473
+ "You are able to understand the visual content that the user provides, "
474
+ "and assist the user with a variety of tasks using natural language.",
475
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n",
476
+ "<|start_header_id|>system<|end_header_id|>\n\n"),
477
+ version="llama_v3",
478
+ messages=(),
479
+ offset=0,
480
+ sep_style=SeparatorStyle.LLAMA_3,
481
+ sep="<|end_of_text|>",
482
+ )
483
+
484
+
485
+ conv_qwen = Conversation(
486
+ system="""<|im_start|>system\n\nYou are a helpful vision-language assistant.""",
487
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
488
+ version="qwen",
489
+ messages=[],
490
+ offset=0,
491
+ sep_style=SeparatorStyle.CHATML,
492
+ sep="<|im_end|>",
493
+ )
494
+
495
+ conv_qwen_2 = Conversation(
496
+ system="A chat between a curious user and an artificial intelligence assistant. "
497
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
498
+ roles=("USER", "ASSISTANT"),
499
+ version="qwen_2",
500
+ messages=(),
501
+ offset=0,
502
+ sep_style=SeparatorStyle.QWEN_2,
503
+ sep=" ",
504
+ sep2="<|endoftext|>",
505
+ )
506
+
507
+ conv_gemma_instruct = Conversation(
508
+ system="",
509
+ roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
510
+ version="gemma",
511
+ messages=[],
512
+ offset=0,
513
+ sep_style=SeparatorStyle.GEMMA,
514
+ sep="<end_of_turn>\n"
515
+ )
516
+
517
+
518
+ default_conversation = conv_plain
519
+ conv_templates = {
520
+ "default": conv_plain,
521
+ "hermes-2": hermes_2,
522
+ "v0": conv_vicuna_v0,
523
+ "v1": conv_vicuna_v1,
524
+ "vicuna_v1": conv_vicuna_v1,
525
+ "vicuna_v1_nosys": conv_vicuna_v1_nosys,
526
+ "llama_2": conv_llama_2,
527
+ "mistral": conv_mistral,
528
+ "plain": conv_plain,
529
+ "llava_v0": conv_llava_v0,
530
+ "v0_mmtag": conv_llava_v0_mmtag,
531
+ "llava_v1": conv_llava_v1,
532
+ "v1_mmtag": conv_llava_v1_mmtag,
533
+ "llava_llama_2": conv_llava_llama_2,
534
+ "mpt": conv_mpt,
535
+
536
+ "llama_3": llama_3_chat,
537
+ "qwen_1_5": conv_qwen,
538
+ "qwen_2": conv_qwen,
539
+ "gemma_instruct": conv_gemma_instruct,
540
+ }
541
+
542
+
543
+ if __name__ == "__main__":
544
+ print(default_conversation.get_prompt())
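
A minimal usage sketch (not part of this commit) for the templates above, following the standard LLaVA-style flow these classes inherit: copy a template, append the user turn (with the image placeholder), append None to leave the assistant turn open, then render the prompt.

    from apollo.conversation import conv_templates

    conv = conv_templates["qwen_1_5"].copy()
    conv.append_message(conv.roles[0], "<image>\nWhat is happening in this video?")
    conv.append_message(conv.roles[1], None)   # None leaves the assistant turn open
    prompt = conv.get_prompt()
    print(prompt)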
apollo/mm_utils.py ADDED
@@ -0,0 +1,579 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ from PIL import Image
18
+ from io import BytesIO
19
+ import base64
20
+ import numpy as np
21
+ import os, math, cv2, re
22
+
23
+ import torch
24
+ from transformers import StoppingCriteria
25
+ from apollo.constants import *
26
+
27
+ import tempfile
28
+ from io import BytesIO
29
+ from decord import VideoReader, cpu
+ 
+ # Needed by ApolloMMLoader.vidprompt below
+ import datetime
+ from num2words import num2words
30
+
31
+
32
+
33
+ def read_video_cv2(video_path, all_indices):
34
+ vidcap = cv2.VideoCapture(video_path)
35
+ frames_dict = {}
36
+ max_index = max(all_indices) # Find the maximum index to avoid unnecessary reading
37
+ count = 0
38
+ success = True
39
+ while success and count <= max_index:
40
+ success, frame = vidcap.read()
41
+ if success and count in all_indices:
42
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
43
+ im_pil = Image.fromarray(img)
44
+ frames_dict[count] = im_pil
45
+ count += 1
46
+ # Now retrieve frames according to all_indices, allowing duplicates
47
+ images = [frames_dict[idx] for idx in all_indices if idx in frames_dict]
48
+ return np.stack([np.array(img) for img in images])
49
+
50
+ def read_video_decord(video_file, all_indices):
51
+ vr = VideoReader(video_file, num_threads=1, ctx=cpu(0))
52
+ return vr.get_batch(all_indices).asnumpy()
53
+
54
+
55
+ def read_video_decord_eval(video_file, all_indices):
56
+ vr = VideoReader(video_file)
57
+ return vr.get_batch(all_indices).asnumpy()
58
+
59
+ def load_frames_from_video(video_file, all_indices, video_decode_backend="decord", eval_=False):
60
+ video_ending = os.path.splitext(video_file)[1]
61
+ if video_ending in ['.gif', '.webm'] or video_decode_backend=="opencv":
62
+ buffer = read_video_cv2(video_file, all_indices)
63
+ else:
64
+ # Use decord for other video formats
65
+ if eval_:
66
+ buffer = read_video_decord_eval(video_file, all_indices)
67
+ else:
68
+ buffer = read_video_decord(video_file, all_indices)
69
+ return buffer # (T, H, W, C)
70
+
71
+ def pad_to_center_square(frames, mean_values):
72
+ """
73
+ Pad the given frame or frames numpy array to square dimensions using the mean values as the padding color.
74
+ Handles both single frames (H, W, C) and batches of frames (N, H, W, C).
75
+
76
+ Args:
77
+ frames (np.array): The input frame array of shape (H, W, C) or (N, H, W, C).
78
+ mean_values (tuple): Mean values for each channel, typically derived from dataset normalization parameters.
79
+
80
+ Returns:
81
+ np.array: The padded frame array with square dimensions.
82
+ """
83
+ if frames.ndim == 3: # Single frame
84
+ frames = frames[np.newaxis, :] # Add a batch dimension
85
+ elif frames.ndim != 4:
86
+ raise ValueError("Input array must be either of shape (H, W, C) or (N, H, W, C)")
87
+
88
+ N, height, width, channels = frames.shape
89
+ size = max(width, height)
90
+ background_color = np.array(mean_values, dtype=frames.dtype)
91
+
92
+ # Create a background array with the size and fill it with the mean values
93
+ padded_frames = np.full((N, size, size, channels), background_color, dtype=frames.dtype)
94
+
95
+ # Calculate padding offsets
96
+ top, left = (size - height) // 2, (size - width) // 2
97
+
98
+ # Place the original frames in the center of the square canvas
99
+ padded_frames[:, top:top + height, left:left + width, :] = frames
100
+ return padded_frames
101
+
102
+
103
+ def expand2square(pil_img, background_color):
104
+ width, height = pil_img.size
105
+ if width == height:
106
+ return pil_img
107
+ elif width > height:
108
+ result = Image.new(pil_img.mode, (width, width), background_color)
109
+ result.paste(pil_img, (0, (width - height) // 2))
110
+ # result.paste(pil_img, (0, 0))
111
+ return result
112
+ else:
113
+ result = Image.new(pil_img.mode, (height, height), background_color)
114
+ result.paste(pil_img, ((height - width) // 2, 0))
115
+ # result.paste(pil_img, (0, 0))
116
+ return result
117
+
118
+
119
+ def calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=1):
120
+ sample_video_fps = frames_per_clip / clip_duration
121
+ num_clips = math.ceil((video_duration / clip_duration) * clip_sampling_ratio)
122
+ frame_step = original_fps / sample_video_fps
123
+ partition_len = total_frames // num_clips
124
+ all_indices, clip_indices, timestamps = [], [], []
125
+ if frame_step > 0.5:
126
+ frame_step = max(1, int(original_fps / sample_video_fps)) #was int/floor
127
+ clip_len = int(frames_per_clip * frame_step) #was int/floor
128
+ sample_len = min(clip_len, total_frames)
129
+ clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
130
+ for i in range(num_clips):
131
+ if partition_len > clip_len:
132
+ start_idx = (partition_len - clip_len) // 2
133
+ end_idx = start_idx + clip_len
134
+ indices = np.arange(start_idx, end_idx, frame_step)
135
+ indices = np.clip(indices, 0, partition_len-1).astype(np.int64)
136
+ indices = indices+ i * partition_len
137
+
138
+ else:
139
+
140
+ indices = np.arange(0, sample_len, frame_step)
141
+ if len(indices) < frames_per_clip:
142
+ padding = np.full(frames_per_clip - len(indices), sample_len)
143
+ indices = np.concatenate((indices, padding))
144
+
145
+ indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
146
+ indices = indices + i * clip_step
147
+
148
+ clip_indices.append(indices)
149
+ all_indices.extend(list(indices))
150
+
151
+ # Calculate timestamps
152
+ start_time = (indices[0] / original_fps)
153
+ end_time = (indices[-1] / original_fps)
154
+ timestamps.append((start_time, end_time))
155
+
156
+ else:
157
+ ## original video FPS too low, we need to sample the same frame multiple times.
158
+ ## Generally should not happen.
159
+ # Calculate the number of times each frame should be sampled
160
+ num_sample = int(np.ceil(1 / frame_step))
161
+
162
+ # Compute the effective clip length considering the frame step
163
+ clip_len = int(frames_per_clip * frame_step)
164
+
165
+ # Create an expanded list of indices with each frame repeated num_sample times
166
+ indices = np.repeat(np.arange(clip_len), num_sample)
167
+
168
+ # Ensure the clip length does not exceed the total number of frames
169
+ clip_len = min(clip_len, len(indices))
170
+ clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
171
+
172
+ sample_len = min(clip_len, total_frames)
173
+ if len(indices) < frames_per_clip:
174
+ padding = np.full(frames_per_clip - len(indices), sample_len)
175
+ indices = np.concatenate((indices, padding))
176
+
177
+ # Distribute the indices into clips
178
+ for i in range(num_clips):
179
+ current_clip_indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
180
+ current_clip_indices = current_clip_indices + i * clip_step
181
+
182
+ # Append the current clip indices to the list of all clips
183
+ clip_indices.append(current_clip_indices)
184
+ all_indices.extend(current_clip_indices)
185
+
186
+ # Calculate timestamps
187
+ start_time = (current_clip_indices[0] / original_fps)
188
+ end_time = (current_clip_indices[-1] / original_fps)
189
+ timestamps.append((start_time, end_time))
190
+
191
+ return clip_indices, all_indices, timestamps
192
+
193
+ def calculate_sample_indices_uniform(frames_per_clip, total_frames, uniform_frame_count, original_fps):
194
+
195
+ # uniform_frame_count (N) is the total number of frames to sample; it is
+ # expected to be a multiple of frames_per_clip so the clips split evenly.
+ N = uniform_frame_count
+ num_clips = N // frames_per_clip
+ 
+ # Generate indices
196
+ if total_frames >= N:
197
+ # Sample N frames uniformly without replacement
198
+ indices = np.linspace(0, total_frames - 1, N, dtype=int)
199
+ else:
200
+ # Not enough frames; repeat frames to reach N frames
201
+ repeats = math.ceil(N / total_frames)
202
+ base_indices = np.arange(total_frames)
203
+ indices = np.tile(base_indices, repeats)[:N]
204
+
205
+ # Split indices into clips
206
+ clip_indices = [
207
+ indices[i * frames_per_clip: (i + 1) * frames_per_clip]
208
+ for i in range(num_clips)
209
+ ]
210
+
211
+ # Calculate timestamps for each clip
212
+ timestamps = []
213
+ for clip in clip_indices:
214
+ start_time = clip[0] / original_fps
215
+ end_time = clip[-1] / original_fps
216
+ timestamps.append((start_time, end_time))
217
+
218
+ all_indices = indices.tolist()
219
+ return clip_indices, all_indices, timestamps
220
+
221
+
222
+ def get_video_details(fname):
223
+ """ Load video content using Decord """
224
+ assert os.path.exists(fname), f'video path not found {fname}'
225
+ _fsize = os.path.getsize(fname)
226
+ assert _fsize >= 1 * 1024, f"video too short {fname}"
227
+ vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
228
+ # Get the total number of frames and the original fps of the video
229
+ total_frames = len(vr)
230
+ original_fps = vr.get_avg_fps()
231
+ video_duration = total_frames / original_fps
232
+ return total_frames, original_fps, video_duration
233
+
234
+
235
+ def get_video_details_cv2(fname):
236
+ """
237
+ Load video content using OpenCV (cv2) and retrieve video details.
238
+
239
+ Args:
240
+ fname (str): Path to the video file.
241
+
242
+ Returns:
243
+ tuple: A tuple containing:
244
+ - total_frames (int): Total number of frames in the video.
245
+ - original_fps (float): Frames per second of the video.
246
+ - video_duration (float): Duration of the video in seconds.
247
+
248
+ Raises:
249
+ AssertionError: If the file does not exist or is too short.
250
+ ValueError: If the video cannot be opened or FPS is zero.
251
+ """
252
+ # Check if the file exists
253
+ assert os.path.exists(fname), f'Video path not found: {fname}'
254
+
255
+ # Check if the file size is at least 1 KB
256
+ _fsize = os.path.getsize(fname)
257
+ assert _fsize >= 1 * 1024, f"Video too short: {fname}"
258
+
259
+ # Open the video file
260
+ cap = cv2.VideoCapture(fname)
261
+ if not cap.isOpened():
262
+ raise ValueError(f"Failed to open video file: {fname}")
263
+
264
+ # Retrieve the total number of frames
265
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
266
+
267
+ # Retrieve the frames per second (FPS)
268
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
269
+ if original_fps == 0:
270
+ cap.release()
271
+ raise ValueError(f"Failed to get FPS for video file: {fname}")
272
+
273
+ # Calculate the video duration in seconds
274
+ video_duration = total_frames / original_fps
275
+
276
+ # Release the video capture object
277
+ cap.release()
278
+
279
+ return total_frames, original_fps, video_duration
280
+
281
+
282
+
283
+ def split_into_clips(video, frames_per_clip):
284
+ """ Split video into a list of clips """
285
+ fpc = frames_per_clip
286
+ nc = len(video) // frames_per_clip
287
+ return [video[i*fpc:(i+1)*fpc] for i in range(nc)]
288
+
289
+ def process_image(vision_processors, frames_per_clip, image):
290
+ mm_data = []
291
+ for vision_processor in vision_processors:
292
+ tmp = expand2square(image, tuple(int(x * 255) for x in vision_processor.image_mean))
293
+ tmp = np.expand_dims(np.asarray(tmp), 0)
294
+ tmp = vision_processor.preprocess(tmp, return_tensors='pt')['pixel_values'][0].unsqueeze(0)
295
+ if len(tmp.shape)==4:
296
+ ## image, need B, T, C, W, H
297
+ tmp = tmp.unsqueeze(1)
298
+ tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
299
+ else:
300
+ ## video, need B, C, T, W, H
301
+ if tmp.shape[1]==1:
302
+ tmp = tmp.repeat_interleave(frames_per_clip, dim=1)
303
+ else:
304
+ tmp = tmp.repeat_interleave(frames_per_clip, dim=2)
305
+
306
+ mm_data.append(tmp)
307
+ return mm_data
308
+
309
+ def process_video(vision_processors, frames_per_clip, buffer):
310
+ mm_data=[]
311
+ for vision_processor in vision_processors:
312
+ centered_buffer = pad_to_center_square(buffer, tuple(int(x * 255) for x in vision_processor.image_mean))
313
+ processed_clips = []
314
+ for clip in split_into_clips(centered_buffer, frames_per_clip):
315
+ clip = vision_processor.preprocess(clip, return_tensors='pt')['pixel_values']
316
+ if type(clip) is list:
317
+ assert len(clip)==1, "LazyVideoDataset: error, vision processor returned clip that is list of len>1 ."
318
+ clip = clip[0]
319
+ processed_clips.append(clip)
320
+ mm_data.append(torch.stack(processed_clips))
321
+ return mm_data
322
+
323
+ def load_video(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False):
324
+ total_frames, original_fps, video_duration = get_video_details(video_file)
325
+ _, all_indices, timestamps = calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
326
+ buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
327
+ mm_data = process_video(vision_processors, frames_per_clip, buffer)
328
+ return mm_data, timestamps
329
+
330
+ def load_video_uniform(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False, uniform_sampling=8):
331
+ total_frames, original_fps, video_duration = get_video_details(video_file)
332
+ all_indices = np.linspace(0, total_frames-1, uniform_sampling, dtype=int)
333
+ print('using uniform frame sampling; sampled', len(all_indices), 'frames')
334
+ buffer = load_frames_from_video(video_file, all_indices, video_decode_backend, eval_)
335
+ mm_data = process_video(vision_processors, frames_per_clip, buffer)
336
+ return mm_data, []
337
+
338
+
339
+
340
+ class ApolloMMLoader:
341
+ def __init__(self, vision_processors, clip_duration, frames_per_clip, num_repeat_token, device, model_max_length = 32768, clip_sampling_ratio=1, video_decode_backend="decord"):
342
+ self.vision_processors=vision_processors
343
+ self.clip_duration=clip_duration
344
+ self.device=device
345
+ self.frames_per_clip=frames_per_clip
346
+ self.num_repeat_token = num_repeat_token
347
+ self.clip_sampling_ratio=clip_sampling_ratio
348
+ self.model_max_length=model_max_length
349
+ self.video_decode_backend=video_decode_backend
350
+ self.vidprompt = lambda num_clips, video_duration : f"You are provided the following series of {num2words(num_clips)}, {self.clip_duration} second clips from a {datetime.timedelta(seconds=video_duration)} [H:MM:SS] video.\n"
351
+
352
+ def load_video(self, video_file):
353
+ total_frames, original_fps, video_duration = get_video_details(video_file)
354
+ clip_sampling_ratio = min(1, (self.model_max_length * self.clip_sampling_ratio) / (video_duration * self.num_repeat_token / self.clip_duration))
355
+
356
+ _, all_indices, timestamps = calculate_sample_indices(self.clip_duration, self.frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=clip_sampling_ratio)
357
+ video, timestamps = load_video(video_file, self.vision_processors, self.clip_duration, self.frames_per_clip, clip_sampling_ratio=clip_sampling_ratio, eval_=True)
358
+
359
+ num_clips = len(video[0])
360
+ num_tokens = num_clips * self.num_repeat_token
361
+ video = [v.to(device=self.device, dtype=torch.bfloat16) for v in video]
362
+ replace_string = self.vidprompt(num_clips, video_duration)
363
+
364
+ temporal_prompt = [f"{round(clip[0], 1)}-{round(clip[1], 1)} seconds: {X_TOKEN['video'] * self.num_repeat_token}" for clip in timestamps]
365
+ temporal_prompt = ',\n'.join(temporal_prompt)
366
+ replace_string = replace_string + temporal_prompt
367
+
368
+ return video, replace_string
369
+
370
+ def load_image(self, image_file):
371
+ print('implement image loading')
372
+ return None
373
+
374
+
375
+ def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
376
+ import cv2
377
+
378
+ if fps is None or frame_count is None:
379
+ # if one of fps or frame_count is None, still recompute
380
+ fps = vidcap.get(cv2.CAP_PROP_FPS)
381
+ frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
382
+ if fps == 0 or frame_count == 0:
383
+ print("Video file not found. return empty images.")
384
+ return [
385
+ Image.new("RGB", (720, 720)),
386
+ ] * num_frames
387
+
388
+ duration = frame_count / fps
389
+ frame_interval = frame_count // num_frames
390
+ if frame_interval == 0 and frame_count <= 1:
391
+ print("frame_interval is equal to 0. return empty image.")
392
+ return [
393
+ Image.new("RGB", (720, 720)),
394
+ ] * num_frames
395
+ # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
396
+
397
+ images = []
398
+ count = 0
399
+ success = True
400
+ frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)
401
+
402
+ while success:
403
+ # print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
404
+ if frame_count >= num_frames:
405
+ success, frame = vidcap.read()
406
+ if count in frame_indices:
407
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
408
+ im_pil = Image.fromarray(img)
409
+ images.append(im_pil)
410
+ if len(images) >= num_frames:
411
+ return images
412
+ count += 1
413
+ else:
414
+ # Left padding frames if the video is not long enough
415
+ success, frame = vidcap.read()
416
+ if success:
417
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
418
+ im_pil = Image.fromarray(img)
419
+ images.append(im_pil)
420
+ count += 1
421
+ elif count >= 1:
422
+ width, height = images[-1].size
423
+ images = [Image.new("RGB", (width, height))] * (num_frames - len(images)) + images
424
+ print("padding frames:", (num_frames - len(images)))
425
+ return images
426
+ else:
427
+ break
428
+ raise ValueError("Did not find enough frames in the video. return empty image.")
429
+
430
+
431
+ def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
432
+ """
433
+ Extract frames from a video using OpenCV.
434
+
435
+ Args:
436
+ vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
437
+ frames (int): Number of frames to extract from the video.
438
+
439
+ Returns:
440
+ list: List of PIL Images extracted from the video.
441
+
442
+ Raises:
443
+ NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
444
+ """
445
+ import cv2
446
+
447
+ if isinstance(vpath_or_bytesio, str):
448
+ vidcap = cv2.VideoCapture(vpath_or_bytesio)
449
+ return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
450
+ elif isinstance(vpath_or_bytesio, (BytesIO,)):
451
+ # assuming mp4
452
+ with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
453
+ temp_video.write(vpath_or_bytesio.read())
454
+ temp_video_name = temp_video.name
455
+ vidcap = cv2.VideoCapture(temp_video_name)
456
+ return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
457
+ else:
458
+ raise NotImplementedError(type(vpath_or_bytesio))
459
+
460
+
461
+ def load_image_from_base64(image):
462
+ return Image.open(BytesIO(base64.b64decode(image)))
463
+
464
+
465
+ def expand2square(pil_img, background_color):
466
+ """
467
+ Expand the given PIL image to a square shape by adding padding.
468
+
469
+ Parameters:
470
+ - pil_img: The PIL image to be expanded.
471
+ - background_color: The color of the padding to be added.
472
+
473
+ Returns:
474
+ - The expanded PIL image.
475
+
476
+ If the image is already square, it is returned as is.
477
+ If the image is wider than it is tall, padding is added to the top and bottom.
478
+ If the image is taller than it is wide, padding is added to the left and right.
479
+ """
480
+ width, height = pil_img.size
481
+ if pil_img.mode == 'L':
482
+ background_color = background_color[0]
483
+ if width == height:
484
+ return pil_img
485
+ elif width > height:
486
+ result = Image.new(pil_img.mode, (width, width), background_color)
487
+ result.paste(pil_img, (0, (width - height) // 2))
488
+ return result
489
+ else:
490
+ result = Image.new(pil_img.mode, (height, height), background_color)
491
+ result.paste(pil_img, ((height - width) // 2, 0))
492
+ return result
493
+
494
+
495
+
496
+ def process_images(images, image_processor, model_cfg):
497
+
498
+ model_cfg.image_processor = image_processor
499
+ new_images = [process_image(image, model_cfg, None) for image in images]
500
+
501
+ if all(x.shape == new_images[0].shape for x in new_images):
502
+ new_images = torch.stack(new_images, dim=0)
503
+ return new_images
504
+
505
+
506
+
507
+
508
+ def tokenizer_mm_token(prompt, tokenizer, return_tensors=None):
509
+ tokens_regex = re.compile('|'.join(re.escape(token) for token in X_TOKEN.values()))
510
+ input_ids, last_pos, start_id = [], 0, 0
511
+ for match in tokens_regex.finditer(prompt):
512
+ if match.start() > last_pos:
513
+ input_ids.extend(tokenizer(prompt[last_pos:match.start()]).input_ids)
514
+ elif match.start() == 0:
515
+ input_ids = tokenizer('').input_ids
516
+ start_id = 1
517
+ input_ids.append(X_TOKEN_INDEX)
518
+ last_pos = match.end()
519
+ if last_pos < len(prompt):
520
+ input_ids.extend(tokenizer(prompt[last_pos:]).input_ids[start_id:])
521
+ return torch.tensor(input_ids, dtype=torch.long) if return_tensors == 'pt' else input_ids
522
+
523
+
524
+ def is_gemma_tokenizer(tokenizer):
525
+ return "gemma" in tokenizer.__class__.__name__.lower()
526
+
527
+
528
+ def get_model_name_from_path(model_path):
529
+ model_path = model_path.strip("/")
530
+ model_paths = model_path.split("/")
531
+ if model_paths[-1].startswith("checkpoint-"):
532
+ return model_paths[-2] + "_" + model_paths[-1]
533
+ else:
534
+ return model_paths[-1]
535
+
536
+
537
+ class KeywordsStoppingCriteria(StoppingCriteria):
538
+ def __init__(self, keywords, tokenizer, input_ids):
539
+ self.keywords = keywords
540
+ self.keyword_ids = []
541
+ self.max_keyword_len = 0
542
+ for keyword in keywords:
543
+ cur_keyword_ids = tokenizer(keyword).input_ids
544
+ if (
545
+ len(cur_keyword_ids) > 1
546
+ and cur_keyword_ids[0] == tokenizer.bos_token_id
547
+ ):
548
+ cur_keyword_ids = cur_keyword_ids[1:]
549
+ if len(cur_keyword_ids) > self.max_keyword_len:
550
+ self.max_keyword_len = len(cur_keyword_ids)
551
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
552
+ self.tokenizer = tokenizer
553
+ self.start_len = input_ids.shape[1]
554
+
555
+ def call_for_batch(
556
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
557
+ ) -> bool:
558
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
559
+ self.keyword_ids = [
560
+ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
561
+ ]
562
+ for keyword_id in self.keyword_ids:
563
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
564
+ return True
565
+ outputs = self.tokenizer.batch_decode(
566
+ output_ids[:, -offset:], skip_special_tokens=True
567
+ )[0]
568
+ for keyword in self.keywords:
569
+ if keyword in outputs:
570
+ return True
571
+ return False
572
+
573
+ def __call__(
574
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
575
+ ) -> bool:
576
+ outputs = []
577
+ for i in range(output_ids.shape[0]):
578
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
579
+ return all(outputs)
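
Finally, a sketch (not part of this commit) of how these pieces are typically wired together for inference. MODEL_PATH, the AutoTokenizer/AutoImageProcessor loading, the video path, and all numeric settings are assumptions rather than values defined in this upload; ApolloMMLoader.load_video and tokenizer_mm_token are used with the signatures defined above.

    from transformers import AutoTokenizer, AutoImageProcessor

    from apollo.mm_utils import ApolloMMLoader, tokenizer_mm_token
    from apollo.constants import X_TOKEN_INDEX

    MODEL_PATH = "path/to/apollo-checkpoint"        # placeholder checkpoint path
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    vision_processor = AutoImageProcessor.from_pretrained(MODEL_PATH)  # assumes the checkpoint ships an image-processor config

    mm_loader = ApolloMMLoader(
        vision_processors=[vision_processor],
        clip_duration=2,            # hypothetical seconds per clip
        frames_per_clip=8,          # hypothetical frames per clip
        num_repeat_token=16,        # hypothetical visual tokens per clip
        device="cuda",
        model_max_length=32768,
    )

    video, replace_string = mm_loader.load_video("example.mp4")        # placeholder video file
    prompt = replace_string + "\n\nDescribe what happens in this video."
    input_ids = tokenizer_mm_token(prompt, tokenizer, return_tensors="pt").unsqueeze(0)
    # Every <|video_token|> in the prompt is replaced by X_TOKEN_INDEX (-200) in input_ids;
    # `video` is a list of clip tensors already moved to the device in bfloat16.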