zhanghaoji committed on
Commit eb0678a (1 parent: 378cd97)
app.py CHANGED
@@ -1,63 +1,173 @@
 
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ import torch
  import gradio as gr
+ from flash_vstream.serve.demo import Chat, title_markdown, block_css
+ from flash_vstream.constants import *
+ from flash_vstream.conversation import conv_templates, Conversation
+ import os
+ from PIL import Image
+ import tempfile
+ import imageio
+ import shutil
+
+
+ model_path = "IVGSZ/Flash-VStream-7b"
+ load_8bit = False
+ load_4bit = False
+
+ def save_image_to_local(image):
+     filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.jpg')
+     image = Image.open(image)
+     image.save(filename)
+     return filename
+
+
+ def save_video_to_local(video_path):
+     filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.mp4')
+     shutil.copyfile(video_path, filename)
+     return filename
+
+
+ def generate(video, textbox_in, first_run, state, state_, images_tensor):
+
+     flag = 1
+     if not textbox_in:
+         if len(state_.messages) > 0:
+             textbox_in = state_.messages[-1][1]
+             state_.messages.pop(-1)
+             flag = 0
+         else:
+             return "Please enter instruction"
+
+     video = video if video else "none"
+
+     if type(state) is not Conversation:
+         state = conv_templates[conv_mode].copy()
+         state_ = conv_templates[conv_mode].copy()
+         images_tensor = []
+
+     first_run = False if len(state.messages) > 0 else True
+
+     text_en_in = textbox_in.replace("picture", "image")
+
+     image_processor = handler.image_processor
+
+     if os.path.exists(video):
+         video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
+         for img in video_tensor:
+             images_tensor.append(image_processor(img, return_tensors='pt')['pixel_values'][0].to(handler.model.device, dtype=torch.float16))
+
+     if os.path.exists(video):
+         text_en_in = DEFAULT_IMAGE_TOKEN * len(video_tensor) + '\n' + text_en_in
+
+     text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
+     state_.messages[-1] = (state_.roles[1], text_en_out)
+
+     text_en_out = text_en_out.split('#')[0]
+     textbox_out = text_en_out
+
+     show_images = ""
+     if os.path.exists(video):
+         filename = save_video_to_local(video)
+         show_images += f'<video controls playsinline width="500" style="display: inline-block;" src="./file={filename}"></video>'
+
+     if flag:
+         state.append_message(state.roles[0], textbox_in + "\n" + show_images)
+     state.append_message(state.roles[1], textbox_out)
+
+     return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=None, interactive=True))
+
+
+ def regenerate(state, state_):
+     state.messages.pop(-1)
+     state_.messages.pop(-1)
+     if len(state.messages) > 0:
+         return state, state_, state.to_gradio_chatbot(), False
+     return (state, state_, state.to_gradio_chatbot(), True)
+
+
+ def clear_history(state, state_):
+     state = conv_templates[conv_mode].copy()
+     state_ = conv_templates[conv_mode].copy()
+     return (gr.update(value=None, interactive=True), \
+         gr.update(value=None, interactive=True),\
+         True, state, state_, state.to_gradio_chatbot(), [])
+
+
+ conv_mode = "simple"
+ handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)
+ if not os.path.exists("temp"):
+     os.makedirs("temp")
+
+ print(torch.cuda.memory_allocated())
+ print(torch.cuda.max_memory_allocated())
+
+ with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
+     gr.Markdown(title_markdown)
+     state = gr.State()
+     state_ = gr.State()
+     first_run = gr.State()
+     images_tensor = gr.State()
+
+     with gr.Row():
+         with gr.Column(scale=3):
+             video = gr.Video(label="Input Video")
+
+         with gr.Column(scale=7):
+             chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True).style(height=700)
+             with gr.Row():
+                 with gr.Column(scale=8):
+                     textbox = gr.Textbox(show_label=False,
+                                          placeholder="Enter text and press Send",
+                                          container=False)
+                 with gr.Column(scale=2, min_width=50):
+                     submit_btn = gr.Button(value="Send", variant="primary", interactive=True)
+
+     with gr.Row(visible=True) as button_row:
+         flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
+         regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
+         clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
+
+     cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+     with gr.Row():
+         gr.Examples(
+             examples=[
+                 [
+                     f"{cur_dir}/examples/video2.mp4",
+                     "Describe the video briefly.",
+                 ]
+             ],
+             inputs=[video, textbox],
+         )
+
+         gr.Examples(
+             examples=[
+                 [
+                     f"{cur_dir}/examples/video4.mp4",
+                     "What is the boy doing?",
+                 ]
+             ],
+             inputs=[video, textbox],
+         )
+
+         gr.Examples(
+             examples=[
+                 [
+                     f"{cur_dir}/examples/video5.mp4",
+                     "Why is this video funny?",
+                 ]
+             ],
+             inputs=[video, textbox],
+         )
+
+     submit_btn.click(generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
+
+     regenerate_btn.click(regenerate, [state, state_], [state, state_, chatbot, first_run]).then(
+         generate, [video, textbox, first_run, state, state_, images_tensor], [state, state_, chatbot, first_run, textbox, images_tensor, video])
+
+     clear_btn.click(clear_history, [state, state_],
+                     [video, textbox, first_run, state, state_, chatbot, images_tensor])
+
+
+ # app = gr.mount_gradio_app(app, demo, path="/")
+ demo.launch()
app_old.py ADDED
@@ -0,0 +1,63 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ """
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+ """
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     messages = [{"role": "system", "content": system_message}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     messages.append({"role": "user", "content": message})
+
+     response = ""
+
+     for message in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = message.choices[0].delta.content
+
+         response += token
+         yield response
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch()
flash_vstream/__init__.py ADDED
@@ -0,0 +1 @@
+ from flash_vstream.model import VStreamLlamaForCausalLM
flash_vstream/constants.py ADDED
@@ -0,0 +1,15 @@
+ # Based on https://github.com/haotian-liu/LLaVA.
+
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
flash_vstream/conversation.py ADDED
@@ -0,0 +1,337 @@
+ # Based on https://github.com/haotian-liu/LLaVA.
+
+ import dataclasses
+ from enum import auto, Enum
+ from typing import List, Tuple
+
+
+ class SeparatorStyle(Enum):
+     """Different separator style."""
+     SINGLE = auto()
+     TWO = auto()
+     MPT = auto()
+     PLAIN = auto()
+     LLAMA_2 = auto()
+
+
+ @dataclasses.dataclass
+ class Conversation:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "###"
+     sep2: str = None
+     version: str = "Unknown"
+
+     skip_next: bool = False
+
+     def get_prompt(self):
+         messages = self.messages
+         if len(messages) > 0 and type(messages[0][1]) is tuple:
+             messages = self.messages.copy()
+             init_role, init_msg = messages[0].copy()
+             init_msg = init_msg[0].replace("<image>", "").strip()
+             if 'mmtag' in self.version:
+                 messages[0] = (init_role, init_msg)
+                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+                 messages.insert(1, (self.roles[1], "Received."))
+             else:
+                 messages[0] = (init_role, "<image>\n" + init_msg)
+
+         if self.sep_style == SeparatorStyle.SINGLE:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + self.sep
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.TWO:
+             seps = [self.sep, self.sep2]
+             ret = self.system + seps[0]
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + ": " + message + seps[i % 2]
+                 else:
+                     ret += role + ":"
+         elif self.sep_style == SeparatorStyle.MPT:
+             ret = self.system + self.sep
+             for role, message in messages:
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += role + message + self.sep
+                 else:
+                     ret += role
+         elif self.sep_style == SeparatorStyle.LLAMA_2:
+             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+             ret = ""
+
+             for i, (role, message) in enumerate(messages):
+                 if i == 0:
+                     assert message, "first message should not be none"
+                     assert role == self.roles[0], "first message should come from user"
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     if i == 0: message = wrap_sys(self.system) + message
+                     if i % 2 == 0:
+                         message = wrap_inst(message)
+                         ret += self.sep + message
+                     else:
+                         ret += " " + message + " " + self.sep2
+                 else:
+                     ret += ""
+             ret = ret.lstrip(self.sep)
+         elif self.sep_style == SeparatorStyle.PLAIN:
+             seps = [self.sep, self.sep2]
+             ret = self.system
+             for i, (role, message) in enumerate(messages):
+                 if message:
+                     if type(message) is tuple:
+                         message, _, _ = message
+                     ret += message + seps[i % 2]
+                 else:
+                     ret += ""
+         else:
+             raise ValueError(f"Invalid style: {self.sep_style}")
+
+         return ret
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     from PIL import Image
+                     msg, image, image_process_mode = msg
+                     if image_process_mode == "Pad":
+                         def expand2square(pil_img, background_color=(122, 116, 104)):
+                             width, height = pil_img.size
+                             if width == height:
+                                 return pil_img
+                             elif width > height:
+                                 result = Image.new(pil_img.mode, (width, width), background_color)
+                                 result.paste(pil_img, (0, (width - height) // 2))
+                                 return result
+                             else:
+                                 result = Image.new(pil_img.mode, (height, height), background_color)
+                                 result.paste(pil_img, ((height - width) // 2, 0))
+                                 return result
+                         image = expand2square(image)
+                     elif image_process_mode in ["Default", "Crop"]:
+                         pass
+                     elif image_process_mode == "Resize":
+                         image = image.resize((336, 336))
+                     else:
+                         raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if longest_edge != max(image.size):
+                         if H > W:
+                             H, W = longest_edge, shortest_edge
+                         else:
+                             H, W = shortest_edge, longest_edge
+                         image = image.resize((W, H))
+                     if return_pil:
+                         images.append(image)
+                     else:
+                         buffered = BytesIO()
+                         image.save(buffered, format="PNG")
+                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                         images.append(img_b64_str)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return Conversation(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version)
+
+     def dict(self):
+         if len(self.get_images()) > 0:
+             return {
+                 "system": self.system,
+                 "roles": self.roles,
+                 "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                 "offset": self.offset,
+                 "sep": self.sep,
+                 "sep2": self.sep2,
+             }
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+         }
+
+
+ conv_vicuna_v0 = Conversation(
+     system="A chat between a curious human and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+     roles=("Human", "Assistant"),
+     messages=(
+         ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+         ("Assistant",
+             "Renewable energy sources are those that can be replenished naturally in a relatively "
+             "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+             "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+             "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+             "renewable and non-renewable energy sources:\n"
+             "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+             "energy sources are finite and will eventually run out.\n"
+             "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+             "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+             "and other negative effects.\n"
+             "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+             "have lower operational costs than non-renewable sources.\n"
+             "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+             "locations than non-renewable sources.\n"
+             "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+             "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+             "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+             "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+     ),
+     offset=2,
+     sep_style=SeparatorStyle.SINGLE,
+     sep="###",
+ )
+
+ conv_vicuna_v1 = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_vicuna_v1_mcq = Conversation(
+     system="A chat between a curious user and an artificial intelligence assistant. "
+            "The assistant gives helpful, detailed, and polite answers to the user's questions. "
+            "The assistant should give the number of correct answer.",
+     roles=("USER", "ASSISTANT"),
+     version="v1",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.TWO,
+     sep=" ",
+     sep2="</s>",
+ )
+
+ conv_tiny = Conversation(
+     system="""<|system|>
+ A conversation between a user and an AI assistant. The assistant gives short and honest answers.""",
+     roles=("<|user|>\n", "<|assistant|>\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="</s>",
+ )
+
+ conv_llama_2 = Conversation(
+     system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+     roles=("USER", "ASSISTANT"),
+     version="llama_v2",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.LLAMA_2,
+     sep="<s>",
+     sep2="</s>",
+ )
+
+ conv_mpt = Conversation(
+     system="""<|im_start|>system
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+     version="mpt",
+     messages=(),
+     offset=0,
+     sep_style=SeparatorStyle.MPT,
+     sep="<|im_end|>",
+ )
+
+ conv_plain = Conversation(
+     system="",
+     roles=("", ""),
+     messages=(
+     ),
+     offset=0,
+     sep_style=SeparatorStyle.PLAIN,
+     sep="\n",
+ )
+
+
+ default_conversation = conv_vicuna_v1
+ conv_templates = {
+     "default": conv_vicuna_v0,
+     "v0": conv_vicuna_v0,
+     "v1": conv_vicuna_v1,
+     "vicuna_v1": conv_vicuna_v1,
+     "llama_2": conv_llama_2,
+     "plain": conv_plain,
+ }
+
+
+ if __name__ == "__main__":
+     print(default_conversation.get_prompt())
flash_vstream/eval_video/eval_activitynet_qa.py ADDED
@@ -0,0 +1,296 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import os
4
+ import ast
5
+ import json
6
+ import openai
7
+ import argparse
8
+ from tqdm import tqdm
9
+ from time import sleep
10
+ from collections import defaultdict
11
+ from multiprocessing.pool import Pool
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(description="question-answer-generation-using-gpt-3")
15
+ parser.add_argument("--pred_path", required=True, help="The path to file containing prediction.")
16
+ parser.add_argument("--output_dir", required=True, help="The path to save annotation json files.")
17
+ parser.add_argument("--output_json", required=True, help="The path to save annotation final combined json file.")
18
+ parser.add_argument("--num_tasks", required=True, type=int, help="Number of splits.")
19
+ parser.add_argument("--num_chunks", default=1, type=int, help="Result splits")
20
+ parser.add_argument("--api_key", required=True, type=str, help="OpenAI API key")
21
+ parser.add_argument("--api_type", default=None, type=str, help="OpenAI API type")
22
+ parser.add_argument("--api_version", default=None, type=str, help="OpenAI API version")
23
+ parser.add_argument("--api_base", default=None, type=str, help="OpenAI API base")
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+
28
+ def annotate(prediction_set, caption_files, output_dir):
29
+ """
30
+ Evaluates question and answer pairs using GPT-3
31
+ Returns a score for correctness.
32
+ """
33
+ for file in tqdm(caption_files):
34
+ key = file[:-5] # Strip file extension
35
+ qa_set = prediction_set[key]
36
+ question = qa_set['q']
37
+ answer = qa_set['a']
38
+ pred = qa_set['pred']
39
+ try:
40
+ # Compute the correctness score
41
+ completion = openai.ChatCompletion.create(
42
+ model="gpt-3.5-turbo",
43
+ messages=[
44
+ {
45
+ "role": "system",
46
+ "content":
47
+ "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
48
+ "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
49
+ "------"
50
+ "##INSTRUCTIONS: "
51
+ "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
52
+ "- Consider synonyms or paraphrases as valid matches.\n"
53
+ "- Evaluate the correctness of the prediction compared to the answer."
54
+ },
55
+ {
56
+ "role": "user",
57
+ "content":
58
+ "Please evaluate the following video-based question-answer pair:\n\n"
59
+ f"Question: {question}\n"
60
+ f"Correct Answer: {answer}\n"
61
+ f"Predicted Answer: {pred}\n\n"
62
+ "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
63
+ "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
64
+ "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
65
+ "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}."
66
+ }
67
+ ],
68
+ temperature=0.002
69
+ )
70
+ # Convert response to a Python dictionary.
71
+ response_message = completion["choices"][0]["message"]["content"]
72
+ response_dict = ast.literal_eval(response_message)
73
+ result_qa_pair = [response_dict, qa_set]
74
+
75
+ # Save the question-answer pairs to a json file.
76
+ with open(f"{output_dir}/{key}.json", "w") as f:
77
+ json.dump(result_qa_pair, f)
78
+ sleep(0.5)
79
+
80
+ except Exception as e:
81
+ print(f"Error processing file '{key}': {e}")
82
+ sleep(1)
83
+
84
+
85
+ def main():
86
+ """
87
+ Main function to control the flow of the program.
88
+ """
89
+ # Parse arguments.
90
+ args = parse_args()
91
+
92
+ if args.num_chunks > 1:
93
+ pred_contents = []
94
+ for _idx in range(args.num_chunks):
95
+ file = os.path.join(args.pred_path, f"{args.num_chunks}_{_idx}.json")
96
+ pred_contents += [json.loads(line) for line in open(file)]
97
+
98
+ else:
99
+ file = os.path.join(args.pred_path, f"pred.json")
100
+ pred_contents = [json.loads(line) for line in open(file)]
101
+
102
+ # Dictionary to store the count of occurrences for each video_id
103
+ video_id_counts = {}
104
+ new_pred_contents = []
105
+
106
+ # Iterate through each sample in pred_contents
107
+ for sample in pred_contents:
108
+ video_id = sample['id']
109
+ if video_id in video_id_counts:
110
+ video_id_counts[video_id] += 1
111
+ else:
112
+ video_id_counts[video_id] = 0
113
+
114
+ # Create a new sample with the modified key
115
+ new_sample = sample
116
+ new_sample['id'] = f"{video_id}_{video_id_counts[video_id]}"
117
+ new_pred_contents.append(new_sample)
118
+
119
+ # Generating list of id's and corresponding files
120
+ id_list = [x['id'] for x in new_pred_contents]
121
+ caption_files = [f"{id}.json" for id in id_list]
122
+
123
+ output_dir = args.output_dir
124
+ # Generate output directory if not exists.
125
+ if not os.path.exists(output_dir):
126
+ os.makedirs(output_dir)
127
+
128
+ # Preparing dictionary of question-answer sets
129
+ prediction_set = {}
130
+ for sample in new_pred_contents:
131
+ id = sample['id']
132
+ question = sample['question']
133
+ answer = sample['answer']
134
+ pred = sample['pred']
135
+ qa_set = {"q": question, "a": answer, "pred": pred, "a_type": sample['answer_type'] if 'answer_type' in sample else None}
136
+ prediction_set[id] = qa_set
137
+
138
+ # Set the OpenAI API key.
139
+ openai.api_key = args.api_key # Your API key here
140
+ if args.api_type:
141
+ openai.api_type = args.api_type
142
+ if args.api_version:
143
+ openai.api_version = args.api_version
144
+ if args.api_base:
145
+ openai.api_base = args.api_base # Your API base here
146
+ num_tasks = args.num_tasks
147
+
148
+ # While loop to ensure that all captions are processed.
149
+ incomplete_lengths = []
150
+ for _ in range(100):
151
+ try:
152
+ # Files that have not been processed yet.
153
+ completed_files = os.listdir(output_dir)
154
+ print(f"completed_files: {len(completed_files)}")
155
+
156
+ # Files that have not been processed yet.
157
+ incomplete_files = [f for f in caption_files if f not in completed_files]
158
+ print(f"incomplete_files: {len(incomplete_files)}")
159
+ incomplete_lengths.append(len(incomplete_files))
160
+ if len(incomplete_lengths) > 5 and len(set(incomplete_lengths[-5:])) <= 1:
161
+ print(f"incomplete_lengths: {incomplete_lengths}")
162
+ print(f"incomplete_files: {incomplete_files}")
163
+ print(f"completed_files: {completed_files}")
164
+ print(f"failed for 5 times, break")
165
+ break
166
+
167
+ # Break the loop when there are no incomplete files
168
+ if len(incomplete_files) == 0:
169
+ break
170
+ if len(incomplete_files) <= num_tasks:
171
+ num_tasks = 1
172
+
173
+ # Split tasks into parts.
174
+ part_len = len(incomplete_files) // num_tasks
175
+ all_parts = [incomplete_files[i:i + part_len] for i in range(0, len(incomplete_files), part_len)]
176
+ task_args = [(prediction_set, part, args.output_dir) for part in all_parts]
177
+
178
+ # Use a pool of workers to process the files in parallel.
179
+ with Pool() as pool:
180
+ pool.starmap(annotate, task_args)
181
+
182
+ except Exception as e:
183
+ print(f"Error: {e}")
184
+
185
+ # Combine all the processed files into one
186
+ combined_contents = {}
187
+ json_path = args.output_json
188
+
189
+ # Iterate through json files
190
+ for file_name in os.listdir(output_dir):
191
+ if file_name.endswith(".json"):
192
+ file_path = os.path.join(output_dir, file_name)
193
+ with open(file_path, "r") as json_file:
194
+ content = json.load(json_file)
195
+ assert 'pred' in content[0], f"Error: {file_name} doesn't have key=pred"
196
+ assert 'score' in content[0], f"Error: {file_name} doesn't have key=score"
197
+ combined_contents[file_name[:-5]] = content
198
+
199
+ # Write combined content to a json file
200
+ with open(json_path, "w") as json_file:
201
+ json.dump(combined_contents, json_file)
202
+ print("All evaluation completed!")
203
+
204
+ class ScoreMeter:
205
+ def __init__(self):
206
+ self.score_sum = 0
207
+ self.count = 0
208
+ self.yes_count = 0
209
+ self.no_count = 0
210
+ self.score_dict = {'yes': defaultdict(int), 'no': defaultdict(int)}
211
+
212
+ def add_score(self, score, pred):
213
+ self.score_sum += score
214
+ self.count += 1
215
+ pred_lower = pred.lower()
216
+ if 'yes' in pred_lower:
217
+ self.yes_count += 1
218
+ self.score_dict['yes'][score] += 1
219
+ elif 'no' in pred_lower:
220
+ self.no_count += 1
221
+ self.score_dict['no'][score] += 1
222
+
223
+ def get_average_score(self):
224
+ res = (self.score_sum / self.count) if self.count else 0
225
+ return f"{res:.6f}"
226
+
227
+ def get_accuracy(self, response_type):
228
+ if response_type == 'yes':
229
+ res = (self.yes_count / self.count) if self.count else 0
230
+ elif response_type == 'no':
231
+ res = (self.no_count / self.count) if self.count else 0
232
+ else:
233
+ res = 0
234
+ return f"{res:.6f}"
235
+
236
+ meter_dic = {'total': ScoreMeter()}
237
+ for key, result in combined_contents.items():
238
+ # Computing score
239
+ score_match = result[0]['score']
240
+ score = int(score_match)
241
+ pred = result[0]['pred']
242
+
243
+ meter_dic["total"].add_score(score, pred)
244
+ if 'a_type' in result[1] and result[1]['a_type'] is not None:
245
+ typ = str(result[1]['a_type'])
246
+ if typ not in meter_dic:
247
+ meter_dic[typ] = ScoreMeter()
248
+ meter_dic[typ].add_score(score, pred)
249
+
250
+ if 'next' in args.output_dir:
251
+ typ = typ[0]
252
+ if typ not in meter_dic:
253
+ meter_dic[typ] = ScoreMeter()
254
+ meter_dic[typ].add_score(score, pred)
255
+
256
+ csv_dic = {'acc': meter_dic["total"].get_accuracy('yes'), 'score': meter_dic["total"].get_average_score()}
257
+
258
+ output = ""
259
+ output += "Yes count: " + str(meter_dic["total"].yes_count) + "\n"
260
+ output += "No count: " + str(meter_dic["total"].no_count) + "\n"
261
+ output += "Accuracy: " + str(meter_dic["total"].get_accuracy('yes')) + "\n"
262
+ output += "Average score: " + str(meter_dic["total"].get_average_score()) + "\n"
263
+ output += "\n"
264
+ output += "Total Score Yes/No distribution:\n"
265
+ for key, value in meter_dic["total"].score_dict.items():
266
+ output += f"{key}:\n"
267
+ for k in range(0, 6):
268
+ v = value[k]
269
+ output += f"{k}: {v}\n"
270
+ output += "\n"
271
+ output += "Answer Type Score distribution:\n"
272
+ output += 'Type, Accuracy, Avg_score\n'
273
+ key_list = sorted([k for k in meter_dic.keys()])
274
+ for key in key_list:
275
+ output += f"{key}, {meter_dic[key].get_accuracy('yes')}, {meter_dic[key].get_average_score()}\n"
276
+ csv_dic[key] = meter_dic[key].get_accuracy('yes')
277
+
278
+ output += "\n"
279
+ for k in csv_dic.keys():
280
+ output += f"{k}, "
281
+ output = output.rstrip(', ') # Remove the trailing comma and space
282
+ output += "\n"
283
+
284
+ for k in csv_dic.keys():
285
+ output += str(csv_dic[k]) + ", "
286
+ output = output.rstrip(', ') # Remove the trailing comma and space
287
+ output += "\n"
288
+
289
+ print(output)
290
+ args.output_csv = args.output_json.replace(".json", ".csv")
291
+ with open(args.output_csv, 'w') as f:
292
+ f.write(output)
293
+
294
+ if __name__ == "__main__":
295
+ main()
296
+
flash_vstream/eval_video/eval_any_dataset_features.py ADDED
@@ -0,0 +1,340 @@
1
+ # Copyright 2024 Flash-VStream Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import argparse
17
+ import subprocess
18
+ import multiprocessing
19
+
20
+ def exec(cmd, sub=False, device=None):
21
+ print(f'exec: {cmd}')
22
+ if not sub:
23
+ if isinstance(cmd, list):
24
+ cmd = ' '.join(cmd)
25
+ os.system(cmd)
26
+ else:
27
+ my_env = os.environ.copy()
28
+ my_env["CUDA_VISIBLE_DEVICES"] = device
29
+ subprocess.run(cmd, env=my_env)
30
+
31
+ # multi gpu, feature
32
+ def eval_msvd(args):
33
+ model_path = args.model_path
34
+ num_chunks = args.num_chunks
35
+ if not args.only_eval:
36
+ processes = []
37
+ for idx in range(0, num_chunks):
38
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
39
+ "--model-path", model_path,
40
+ "--video_dir", "./data/eval_video/MSVD-QA/video_features",
41
+ "--gt_file", "./data/eval_video/MSVD-QA/test_qa.json",
42
+ "--output_dir", os.path.join(model_path, "evaluation", "msvd"),
43
+ "--output_name", "pred",
44
+ "--num-chunks", str(num_chunks),
45
+ "--chunk-idx", str(idx),
46
+ "--conv-mode", "vicuna_v1"]
47
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
48
+ processes.append(p)
49
+ p.start() # launch the subprocess
50
+ for p in processes:
51
+ p.join()
52
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
53
+ "--pred_path", os.path.join(model_path, "evaluation", "msvd"),
54
+ "--output_dir", os.path.join(model_path, "evaluation", "msvd", "results"),
55
+ "--output_json", os.path.join(model_path, "evaluation", "msvd", "results.json"),
56
+ "--num_chunks", str(num_chunks),
57
+ "--num_tasks", "16",
58
+ "--api_key", args.api_key,
59
+ "--api_base", args.api_base,
60
+ "--api_type", args.api_type,
61
+ "--api_version", args.api_version,
62
+ ]
63
+ exec(cmd)
64
+
65
+ # multi gpu, feature
66
+ def eval_msrvtt(args):
67
+ model_path = args.model_path
68
+ num_chunks = args.num_chunks
69
+ if not args.only_eval:
70
+ processes = []
71
+ for idx in range(0, num_chunks):
72
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
73
+ "--model-path", model_path,
74
+ "--video_dir", "./data/eval_video/MSRVTT-QA/video_features",
75
+ "--gt_file", "./data/eval_video/MSRVTT-QA/test_qa.json",
76
+ "--output_dir", os.path.join(model_path, "evaluation", "msrvtt"),
77
+ "--output_name", "pred",
78
+ "--num-chunks", str(num_chunks),
79
+ "--chunk-idx", str(idx),
80
+ "--conv-mode", "vicuna_v1"]
81
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
82
+ processes.append(p)
83
+ p.start() # launch the subprocess
84
+ for p in processes:
85
+ p.join()
86
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
87
+ "--pred_path", os.path.join(model_path, "evaluation", "msrvtt"),
88
+ "--output_dir", os.path.join(model_path, "evaluation", "msrvtt", "results"),
89
+ "--output_json", os.path.join(model_path, "evaluation", "msrvtt", "results.json"),
90
+ "--num_chunks", str(num_chunks),
91
+ "--num_tasks", "16",
92
+ "--api_key", args.api_key,
93
+ "--api_base", args.api_base,
94
+ "--api_type", args.api_type,
95
+ "--api_version", args.api_version,
96
+ ]
97
+ exec(cmd)
98
+
99
+ # multi gpu, feature
100
+ def eval_actnet(args):
101
+ model_path = args.model_path
102
+ num_chunks = args.num_chunks
103
+ if not args.only_eval:
104
+ processes = []
105
+ for idx in range(0, num_chunks):
106
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
107
+ "--model-path", model_path,
108
+ "--video_dir", "./data/eval_video/ActivityNet-QA/video_features",
109
+ "--gt_file", "./data/eval_video/ActivityNet-QA/test_qa.json",
110
+ "--output_dir", os.path.join(model_path, "evaluation", "actnet"),
111
+ "--output_name", "pred",
112
+ "--num-chunks", str(num_chunks),
113
+ "--chunk-idx", str(idx),
114
+ "--conv-mode", "vicuna_v1",
115
+ ]
116
+
117
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
118
+ processes.append(p)
119
+ p.start() # launch the subprocess
120
+ for p in processes:
121
+ p.join()
122
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
123
+ "--pred_path", os.path.join(model_path, "evaluation", "actnet"),
124
+ "--output_dir", os.path.join(model_path, "evaluation", "actnet", "results"),
125
+ "--output_json", os.path.join(model_path, "evaluation", "actnet", "results.json"),
126
+ "--num_chunks", str(num_chunks),
127
+ "--num_tasks", "16",
128
+ "--api_key", args.api_key,
129
+ "--api_base", args.api_base,
130
+ "--api_type", args.api_type,
131
+ "--api_version", args.api_version,
132
+ ]
133
+ exec(cmd)
134
+
135
+ # multi gpu, feature
136
+ def eval_nextoe(args): # follow msvd format, OE follow actnet
137
+ model_path = args.model_path
138
+ num_chunks = args.num_chunks
139
+ if not args.only_eval:
140
+ processes = []
141
+ for idx in range(0, num_chunks):
142
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
143
+ "--model-path", model_path,
144
+ "--video_dir", "./data/eval_video/nextoe/video_features",
145
+ "--gt_file", "./data/eval_video/nextoe/test_qa.json",
146
+ "--output_dir", os.path.join(model_path, "evaluation", "nextoe"),
147
+ "--output_name", "pred",
148
+ "--num-chunks", str(num_chunks),
149
+ "--chunk-idx", str(idx),
150
+ "--conv-mode", "vicuna_v1",
151
+ ]
152
+
153
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
154
+ processes.append(p)
155
+ p.start() # launch the subprocess
156
+ for p in processes:
157
+ p.join()
158
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
159
+ "--pred_path", os.path.join(model_path, "evaluation", "nextoe"),
160
+ "--output_dir", os.path.join(model_path, "evaluation", "nextoe", "results"),
161
+ "--output_json", os.path.join(model_path, "evaluation", "nextoe", "results.json"),
162
+ "--num_chunks", str(num_chunks),
163
+ "--num_tasks", "16",
164
+ "--api_key", args.api_key,
165
+ "--api_base", args.api_base,
166
+ "--api_type", args.api_type,
167
+ "--api_version", args.api_version,
168
+ ]
169
+ exec(cmd)
170
+
171
+ # multi gpu, feature
172
+ def eval_vsmovienet(args): # follow msvd format
173
+ model_path = args.model_path
174
+ num_chunks = args.num_chunks
175
+ if not args.only_eval:
176
+ processes = []
177
+ for idx in range(0, num_chunks):
178
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
179
+ "--model-path", model_path,
180
+ "--video_dir", "./data/eval_video/vstream/movienet_video_features",
181
+ "--gt_file", "./data/eval_video/vstream/test_qa_movienet.json",
182
+ "--output_dir", os.path.join(model_path, "evaluation", "vsmovienet"),
183
+ "--output_name", "pred",
184
+ "--num-chunks", str(num_chunks),
185
+ "--chunk-idx", str(idx),
186
+ "--conv-mode", "vicuna_v1",
187
+ ]
188
+
189
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
190
+ processes.append(p)
191
+ p.start() # launch the subprocess
192
+ for p in processes:
193
+ p.join()
194
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
195
+ "--pred_path", os.path.join(model_path, "evaluation", "vsmovienet"),
196
+ "--output_dir", os.path.join(model_path, "evaluation", "vsmovienet", "results"),
197
+ "--output_json", os.path.join(model_path, "evaluation", "vsmovienet", "results.json"),
198
+ "--num_chunks", str(num_chunks),
199
+ "--num_tasks", "16",
200
+ "--api_key", args.api_key,
201
+ "--api_base", args.api_base,
202
+ "--api_type", args.api_type,
203
+ "--api_version", args.api_version,
204
+ ]
205
+ exec(cmd)
206
+
207
+ # multi gpu, feature
208
+ def eval_vsego4d(args): # follow msvd format
209
+ model_path = args.model_path
210
+ num_chunks = args.num_chunks
211
+ if not args.only_eval:
212
+ processes = []
213
+ for idx in range(0, num_chunks):
214
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
215
+ "--model-path", model_path,
216
+ "--video_dir", "./data/eval_video/vstream/ego4d_video_features",
217
+ "--gt_file", "./data/eval_video/vstream/test_qa_ego4d.json",
218
+ "--output_dir", os.path.join(model_path, "evaluation", "vsego4d"),
219
+ "--output_name", "pred",
220
+ "--num-chunks", str(num_chunks),
221
+ "--chunk-idx", str(idx),
222
+ "--conv-mode", "vicuna_v1",
223
+ ]
224
+
225
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
226
+ processes.append(p)
227
+ p.start() # launch the subprocess
228
+ for p in processes:
229
+ p.join()
230
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
231
+ "--pred_path", os.path.join(model_path, "evaluation", "vsego4d"),
232
+ "--output_dir", os.path.join(model_path, "evaluation", "vsego4d", "results"),
233
+ "--output_json", os.path.join(model_path, "evaluation", "vsego4d", "results.json"),
234
+ "--num_chunks", str(num_chunks),
235
+ "--num_tasks", "16",
236
+ "--api_key", args.api_key,
237
+ "--api_base", args.api_base,
238
+ "--api_type", args.api_type,
239
+ "--api_version", args.api_version,
240
+ ]
241
+ exec(cmd)
242
+
243
+ # multi gpu, feature
244
+ def eval_realtime_vsmovienet(args): # follow msvd format
245
+ model_path = args.model_path
246
+ num_chunks = args.num_chunks
247
+ if not args.only_eval:
248
+ processes = []
249
+ for idx in range(0, num_chunks):
250
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
251
+ "--model-path", model_path,
252
+ "--video_dir", "./data/eval_video/vstream-realtime/movienet_video_features",
253
+ "--gt_file", "./data/eval_video/vstream-realtime/test_qa_movienet.json",
254
+ "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet"),
255
+ "--output_name", "pred",
256
+ "--num-chunks", str(num_chunks),
257
+ "--chunk-idx", str(idx),
258
+ "--conv-mode", "vicuna_v1",
259
+ ]
260
+
261
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
262
+ processes.append(p)
263
+ p.start() # launch the subprocess
264
+ for p in processes:
265
+ p.join()
266
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
267
+ "--pred_path", os.path.join(model_path, "evaluation", "realtime_vsmovienet"),
268
+ "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results"),
269
+ "--output_json", os.path.join(model_path, "evaluation", "realtime_vsmovienet", "results.json"),
270
+ "--num_chunks", str(num_chunks),
271
+ "--num_tasks", "16",
272
+ "--api_key", args.api_key,
273
+ "--api_base", args.api_base,
274
+ "--api_type", args.api_type,
275
+ "--api_version", args.api_version,
276
+ ]
277
+ exec(cmd)
278
+
279
+ # multi gpu, feature
280
+ def eval_realtime_vsego4d(args): # follow msvd format
281
+ model_path = args.model_path
282
+ num_chunks = args.num_chunks
283
+ if not args.only_eval:
284
+ processes = []
285
+ for idx in range(0, num_chunks):
286
+ cmd = ["python", "llama_vstream/eval_video/model_msvd_qa_featuresloader.py",
287
+ "--model-path", model_path,
288
+ "--video_dir", "./data/eval_video/vstream-realtime/ego4d_video_features",
289
+ "--gt_file", "./data/eval_video/vstream-realtime/test_qa_ego4d.json",
290
+ "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d"),
291
+ "--output_name", "pred",
292
+ "--num-chunks", str(num_chunks),
293
+ "--chunk-idx", str(idx),
294
+ "--conv-mode", "vicuna_v1",
295
+ ]
296
+
297
+ p = multiprocessing.Process(target=exec, args=(cmd, True, str(idx)))
298
+ processes.append(p)
299
+ p.start() # launch the subprocess
300
+ for p in processes:
301
+ p.join()
302
+ cmd = ["python", "llama_vstream/eval_video/eval_activitynet_qa.py",
303
+ "--pred_path", os.path.join(model_path, "evaluation", "realtime_vsego4d"),
304
+ "--output_dir", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results"),
305
+ "--output_json", os.path.join(model_path, "evaluation", "realtime_vsego4d", "results.json"),
306
+ "--num_chunks", str(num_chunks),
307
+ "--num_tasks", "16",
308
+ "--api_key", args.api_key,
309
+ "--api_base", args.api_base,
310
+ "--api_type", args.api_type,
311
+ "--api_version", args.api_version,
312
+ ]
313
+ exec(cmd)
314
+
315
+
316
+ if __name__ == "__main__":
317
+ parser = argparse.ArgumentParser()
318
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
319
+ parser.add_argument("--dataset", type=str, default=None)
320
+ parser.add_argument("--api_key", type=str, default=None)
321
+ parser.add_argument("--api_base", type=str, default=None)
322
+ parser.add_argument("--api_type", type=str, default=None)
323
+ parser.add_argument("--api_version", type=str, default=None)
324
+ parser.add_argument("--num_chunks", type=int, default=1)
325
+ parser.add_argument("--only_eval", action="store_true")
326
+ parser.add_argument("--vizlen", type=int, default=0)
327
+ parser.add_argument("--use_speech", action="store_true", default=False)
328
+ args = parser.parse_args()
329
+ func_dic = {'msvd': eval_msvd,
330
+ 'msrvtt': eval_msrvtt,
331
+ 'actnet': eval_actnet,
332
+ 'nextoe': eval_nextoe,
333
+ 'vsmovienet': eval_vsmovienet,
334
+ 'vsego4d': eval_vsego4d,
335
+ 'realtime_vsmovienet': eval_realtime_vsmovienet,
336
+ 'realtime_vsego4d': eval_realtime_vsego4d,
337
+ }
338
+ if args.dataset in func_dic:
339
+ print(f'Execute {args.dataset} evaluation')
340
+ func_dic[args.dataset](args)
flash_vstream/eval_video/model_msvd_qa.py ADDED
@@ -0,0 +1,157 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import os
4
+ import json
5
+ import math
6
+ import torch
7
+ import argparse
8
+ from tqdm import tqdm
9
+ from decord import VideoReader, cpu
10
+
11
+ from llama_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
12
+ from llama_vstream.conversation import conv_templates, SeparatorStyle
13
+ from llama_vstream.model.builder import load_pretrained_model
14
+ from llama_vstream.utils import disable_torch_init
15
+ from llama_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
16
+
17
+
18
+ def split_list(lst, n):
19
+ """Split a list into n (roughly) equal-sized chunks"""
20
+ chunk_size = math.ceil(len(lst) / n) # integer division
21
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22
+
23
+
24
+ def get_chunk(lst, n, k):
25
+ chunks = split_list(lst, n)
26
+ return chunks[k]
27
+
28
+
29
+ def parse_args():
30
+ """
31
+ Parse command-line arguments.
32
+ """
33
+ parser = argparse.ArgumentParser()
34
+
35
+ # Define the command-line arguments
36
+ parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
37
+ parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True)
38
+ parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
39
+ parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
40
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
41
+ parser.add_argument("--model-base", type=str, default=None)
42
+ parser.add_argument("--conv-mode", type=str, default=None)
43
+ parser.add_argument("--num-chunks", type=int, default=1)
44
+ parser.add_argument("--chunk-idx", type=int, default=0)
45
+ parser.add_argument("--model-max-length", type=int, default=None)
46
+
47
+ return parser.parse_args()
48
+
49
+
50
+ def load_video(video_path):
51
+ vr = VideoReader(video_path, ctx=cpu(0))
52
+ total_frame_num = len(vr)
53
+ fps = round(vr.get_avg_fps())
54
+ frame_idx = [i for i in range(0, len(vr), fps)]
55
+ spare_frames = vr.get_batch(frame_idx).asnumpy()
56
+ return spare_frames
57
+
58
+
59
+ def run_inference(args):
60
+ """
61
+ Run inference on ActivityNet QA DataSet using the Video-ChatGPT model.
62
+
63
+ Args:
64
+ args: Command-line arguments.
65
+ """
66
+ # Initialize the model
67
+ model_name = get_model_name_from_path(args.model_path)
68
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length)
69
+
70
+ # Load both ground truth file containing questions and answers
71
+ with open(args.gt_file) as file:
72
+ gt_questions = json.load(file)
73
+ gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
74
+
75
+ # Create the output directory if it doesn't exist
76
+ if not os.path.exists(args.output_dir):
77
+ try:
78
+ os.makedirs(args.output_dir)
79
+ except Exception as e:
80
+ print(f'mkdir Except: {e}')
81
+
82
+ video_formats = ['.mp4', '.avi', '.mov', '.mkv']
83
+ if args.num_chunks > 1:
84
+ output_name = f"{args.num_chunks}_{args.chunk_idx}"
85
+ else:
86
+ output_name = args.output_name
87
+ answers_file = os.path.join(args.output_dir, f"{output_name}.json")
88
+ ans_file = open(answers_file, "w")
89
+
90
+ for sample in tqdm(gt_questions, desc=f"cuda:{args.chunk_idx} "):
91
+ video_name = sample['video_id']
92
+ question = sample['question']
93
+ id = sample['id']
94
+ answer = sample['answer']
95
+
96
+ sample_set = {'id': id, 'question': question, 'answer': answer}
97
+
98
+ # Load the video file
99
+ for fmt in video_formats: # Added this line
100
+ temp_path = os.path.join(args.video_dir, f"{video_name}{fmt}")
101
+ if os.path.exists(temp_path):
102
+ video_path = temp_path
103
+ break
104
+
105
+ # Check if the video exists
106
+ if os.path.exists(video_path):
107
+ video = load_video(video_path)
108
+ video = image_processor.preprocess(video, return_tensors='pt')['pixel_values'].half().cuda()
109
+ video = [video]
110
+
111
+ qs = question
112
+ if model.config.mm_use_im_start_end:
113
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
114
+ else:
115
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
116
+
117
+ conv = conv_templates[args.conv_mode].copy()
118
+ conv.append_message(conv.roles[0], qs)
119
+ conv.append_message(conv.roles[1], None)
120
+ prompt = conv.get_prompt()
121
+
122
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
123
+
124
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
125
+ keywords = [stop_str]
126
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
127
+
128
+ with torch.inference_mode():
129
+ output_ids = model.generate(
130
+ input_ids,
131
+ images=video,
132
+ do_sample=True,
133
+ temperature=0.002,
134
+ max_new_tokens=1024,
135
+ use_cache=True,
136
+ stopping_criteria=[stopping_criteria])
137
+
138
+ input_token_len = input_ids.shape[1]
139
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
140
+ if n_diff_input_output > 0:
141
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
142
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
143
+ outputs = outputs.strip()
144
+ if outputs.endswith(stop_str):
145
+ outputs = outputs[:-len(stop_str)]
146
+ outputs = outputs.strip()
147
+
148
+ sample_set['pred'] = outputs
149
+ ans_file.write(json.dumps(sample_set) + "\n")
150
+ ans_file.flush()
151
+
152
+ ans_file.close()
153
+
154
+
155
+ if __name__ == "__main__":
156
+ args = parse_args()
157
+ run_inference(args)
flash_vstream/eval_video/model_msvd_qa_featuresloader.py ADDED
@@ -0,0 +1,179 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # Based on https://github.com/haotian-liu/LLaVA.
3
+
4
+ import os
5
+ import json
6
+ import math
7
+ import torch
8
+ import random
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from safetensors.torch import load_file
13
+
14
+ from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
15
+ from flash_vstream.conversation import conv_templates, SeparatorStyle
16
+ from flash_vstream.model.builder import load_pretrained_model
17
+ from flash_vstream.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
18
+
19
+
20
+ def split_list(lst, n):
21
+ """Split a list into n (roughly) equal-sized chunks"""
22
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
23
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
24
+
25
+
26
+ def get_chunk(lst, n, k):
27
+ chunks = split_list(lst, n)
28
+ return chunks[k]
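+ # Sharding sketch (illustrative): with 10 questions and --num-chunks 3,
+ # split_list(list(range(10)), 3) -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
+ # and get_chunk(list(range(10)), 3, 1) -> [4, 5, 6, 7], i.e. the shard handled by --chunk-idx 1.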
29
+
30
+
31
+ def parse_args():
32
+ """
33
+ Parse command-line arguments.
34
+ """
35
+ parser = argparse.ArgumentParser()
36
+
37
+ # Define the command-line arguments
38
+ parser.add_argument('--video_dir', help='Directory containing video files.', required=True)
39
+ parser.add_argument('--gt_file', help='Path to the ground truth file containing question.', required=True)
40
+ parser.add_argument('--output_dir', help='Directory to save the model results JSON.', required=True)
41
+ parser.add_argument('--output_name', help='Name of the file for storing results JSON.', required=True)
42
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
43
+ parser.add_argument("--model-base", type=str, default=None)
44
+ parser.add_argument("--conv-mode", type=str, default=None)
45
+ parser.add_argument("--num-chunks", type=int, default=1)
46
+ parser.add_argument("--chunk-idx", type=int, default=0)
47
+ parser.add_argument("--model-max-length", type=int, default=None)
48
+ return parser.parse_args()
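+ # Example invocation (every path and the conv mode below are placeholders, not values fixed by this repo):
+ #   python flash_vstream/eval_video/model_msvd_qa_featuresloader.py \
+ #       --model-path <path-to-flash-vstream-checkpoint> --conv-mode vicuna_v1 \
+ #       --video_dir ./msvd_features --gt_file ./msvd_qa.json \
+ #       --output_dir ./results --output_name pred --num-chunks 8 --chunk-idx 0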
49
+
50
+
51
+ class CustomDataset(Dataset):
52
+ def __init__(self, questions, video_dir, tokenizer, image_processor, model_config):
53
+ self.questions = questions
54
+ self.video_dir = video_dir
55
+ self.tokenizer = tokenizer
56
+ self.image_processor = image_processor
57
+ self.model_config = model_config
58
+
59
+ def __getitem__(self, index):
60
+ sample = self.questions[index]
61
+ video_name = sample['video_id']
62
+ try:
63
+ video_path = os.path.join(self.video_dir, video_name + '.safetensors')
64
+ video_tensor = load_file(video_path)['feature']
65
+ except Exception as e:
66
+ print(f'Dataset Exception: {e}, randomly choose one.')
67
+ idx = random.randint(0, len(self.questions) - 1)
68
+ return self.__getitem__(idx)
69
+ qs = sample['question']
70
+ if self.model_config.mm_use_im_start_end:
71
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
72
+ else:
73
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
74
+ conv = conv_templates[args.conv_mode].copy()
75
+ if 'system' in sample:
76
+ conv.system = conv.system + ' ' + sample['system']
77
+ conv.append_message(conv.roles[0], qs)
78
+ conv.append_message(conv.roles[1], None)
79
+ prompt = conv.get_prompt()
80
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
81
+ return input_ids, video_tensor
82
+
83
+ def __len__(self):
84
+ return len(self.questions)
85
+
86
+
87
+ def create_data_loader(questions, video_dir, tokenizer, image_processor, model_config, batch_size=1, num_workers=2):
88
+ assert batch_size == 1, "batch_size must be 1"
89
+ dataset = CustomDataset(questions, video_dir, tokenizer, image_processor, model_config)
90
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
91
+ return data_loader
92
+
93
+
94
+ def run_inference(args):
95
+ """
96
+ Run inference on the MSVD-QA dataset using the Flash-VStream model.
97
+
98
+ Args:
99
+ args: Command-line arguments.
100
+ """
101
+ # Initialize the model
102
+ model_name = get_model_name_from_path(args.model_path)
103
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.model_max_length)
104
+
105
+ # Load the ground truth file containing both questions and answers
106
+ with open(args.gt_file) as file:
107
+ gt_questions = json.load(file)
108
+ gt_questions = get_chunk(gt_questions, args.num_chunks, args.chunk_idx)
109
+
110
+ # Create the output directory if it doesn't exist
111
+ if not os.path.exists(args.output_dir):
112
+ try:
113
+ os.makedirs(args.output_dir)
114
+ except Exception as e:
115
+ print(f'mkdir Except: {e}')
116
+
117
+ video_formats = ['.mp4', '.avi', '.mov', '.mkv']
118
+ if args.num_chunks > 1:
119
+ output_name = f"{args.num_chunks}_{args.chunk_idx}"
120
+ else:
121
+ output_name = args.output_name
122
+ answers_file = os.path.join(args.output_dir, f"{output_name}.json")
123
+ # resume from old exp
124
+ exist_id_set = set()
125
+ if os.path.exists(answers_file):
126
+ with open(answers_file) as f:
127
+ exist_pred_contents = [json.loads(line) for line in f]
128
+ exist_id_set = set([x['id'] for x in exist_pred_contents])
129
+
130
+ new_gt_questions = []
131
+ for sample in tqdm(gt_questions):
132
+ if not sample['id'] in exist_id_set:
133
+ new_gt_questions.append(sample)
134
+ gt_questions = new_gt_questions
135
+
136
+ data_loader = create_data_loader(gt_questions, args.video_dir, tokenizer, image_processor, model.config)
137
+
138
+ conv = conv_templates[args.conv_mode].copy()
139
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
140
+ keywords = [stop_str]
141
+
142
+ with open(answers_file, "a") as ans_file:
143
+ for data, sample in tqdm(zip(data_loader, gt_questions), desc=f"cuda:{args.chunk_idx} ", total=len(gt_questions)):
144
+ input_ids, video_tensors = data
145
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
146
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
147
+ with torch.inference_mode():
148
+ output_ids = model.generate(
149
+ input_ids,
150
+ features=video_tensors.to(dtype=torch.float16, device='cuda', non_blocking=True),
151
+ do_sample=True,
152
+ temperature=0.002,
153
+ max_new_tokens=1024,
154
+ use_cache=True,
155
+ stopping_criteria=[stopping_criteria],
156
+ )
157
+ input_token_len = input_ids.shape[1]
158
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
159
+ if n_diff_input_output > 0:
160
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
161
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
162
+ outputs = outputs.strip()
163
+ if outputs.endswith(stop_str):
164
+ outputs = outputs[:-len(stop_str)]
165
+ outputs = outputs.strip()
166
+ sample_set = {
167
+ 'id': sample['id'],
168
+ 'question': sample['question'],
169
+ 'answer': sample['answer'],
170
+ 'answer_type': sample['answer_type'] if 'answer_type' in sample else None,
171
+ 'pred': outputs
172
+ }
173
+ ans_file.write(json.dumps(sample_set) + "\n")
174
+ ans_file.flush()
175
+
176
+
177
+ if __name__ == "__main__":
178
+ args = parse_args()
179
+ run_inference(args)
flash_vstream/mm_utils.py ADDED
@@ -0,0 +1,106 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import base64
6
+
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+ from flash_vstream.constants import IMAGE_TOKEN_INDEX
10
+
11
+
12
+ def load_image_from_base64(image):
13
+ return Image.open(BytesIO(base64.b64decode(image)))
14
+
15
+
16
+ def expand2square(pil_img, background_color):
17
+ width, height = pil_img.size
18
+ if width == height:
19
+ return pil_img
20
+ elif width > height:
21
+ result = Image.new(pil_img.mode, (width, width), background_color)
22
+ result.paste(pil_img, (0, (width - height) // 2))
23
+ return result
24
+ else:
25
+ result = Image.new(pil_img.mode, (height, height), background_color)
26
+ result.paste(pil_img, ((height - width) // 2, 0))
27
+ return result
28
+
29
+
30
+ def process_images(images, image_processor, model_cfg):
31
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
32
+ new_images = []
33
+ if image_aspect_ratio == 'pad':
34
+ for image in images:
35
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
36
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
37
+ new_images.append(image)
38
+ else:
39
+ return image_processor(images, return_tensors='pt')['pixel_values']
40
+ if all(x.shape == new_images[0].shape for x in new_images):
41
+ new_images = torch.stack(new_images, dim=0)
42
+ return new_images
43
+
44
+
45
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
46
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
47
+
48
+ def insert_separator(X, sep):
49
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
50
+
51
+ input_ids = []
52
+ offset = 0
53
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
54
+ offset = 1
55
+ input_ids.append(prompt_chunks[0][0])
56
+
57
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
58
+ input_ids.extend(x[offset:])
59
+
60
+ if return_tensors is not None:
61
+ if return_tensors == 'pt':
62
+ return torch.tensor(input_ids, dtype=torch.long)
63
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
64
+ return input_ids
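+ # Illustrative call: tokenizer_image_token("USER: <image>\nWhat happens in the video? ASSISTANT:", tokenizer, return_tensors='pt')
+ # replaces every "<image>" placeholder with the IMAGE_TOKEN_INDEX sentinel (a negative id in LLaVA-style code),
+ # so the model can later splice the projected visual features in at those positions.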
65
+
66
+
67
+ def get_model_name_from_path(model_path):
68
+ model_path = model_path.strip("/")
69
+ model_paths = model_path.split("/")
70
+ if model_paths[-1].startswith('checkpoint-'):
71
+ return model_paths[-2] + "_" + model_paths[-1]
72
+ else:
73
+ return model_paths[-1]
74
+
75
+ class KeywordsStoppingCriteria(StoppingCriteria):
76
+ def __init__(self, keywords, tokenizer, input_ids):
77
+ self.keywords = keywords
78
+ self.keyword_ids = []
79
+ self.max_keyword_len = 0
80
+ for keyword in keywords:
81
+ cur_keyword_ids = tokenizer(keyword).input_ids
82
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
83
+ cur_keyword_ids = cur_keyword_ids[1:]
84
+ if len(cur_keyword_ids) > self.max_keyword_len:
85
+ self.max_keyword_len = len(cur_keyword_ids)
86
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
87
+ self.tokenizer = tokenizer
88
+ self.start_len = input_ids.shape[1]
89
+
90
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
91
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
92
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
93
+ for keyword_id in self.keyword_ids:
94
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
95
+ return True
96
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
97
+ for keyword in self.keywords:
98
+ if keyword in outputs:
99
+ return True
100
+ return False
101
+
102
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
103
+ outputs = []
104
+ for i in range(output_ids.shape[0]):
105
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
106
+ return all(outputs)
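+ # Note: generate() evaluates this criteria after every decoding step; generation stops once every
+ # sequence in the batch has produced one of the keyword strings (e.g. the conversation stop token).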
flash_vstream/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .language_model.vstream_llama import VStreamLlamaForCausalLM, VStreamConfig
flash_vstream/model/builder.py ADDED
@@ -0,0 +1,139 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # ------------------------------------------------------------------------
3
+ # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4
+ # Copyright 2023 Haotian Liu
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import os
20
+ import warnings
21
+ import shutil
22
+
23
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
24
+ import torch
25
+ from flash_vstream.model import VStreamLlamaForCausalLM, VStreamConfig
26
+ from flash_vstream.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
+
28
+
29
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
30
+ kwargs = {"device_map": device_map, **kwargs}
31
+
32
+ if device != "cuda":
33
+ kwargs['device_map'] = {"": device}
34
+
35
+ if load_8bit:
36
+ kwargs['load_in_8bit'] = True
37
+ elif load_4bit:
38
+ kwargs['load_in_4bit'] = True
39
+ kwargs['quantization_config'] = BitsAndBytesConfig(
40
+ load_in_4bit=True,
41
+ bnb_4bit_compute_dtype=torch.float16,
42
+ bnb_4bit_use_double_quant=True,
43
+ bnb_4bit_quant_type='nf4'
44
+ )
45
+ else:
46
+ kwargs['torch_dtype'] = torch.float16
47
+
48
+ if 'vstream' in model_name.lower():
49
+ # Load LLaMA-VStream model
50
+ if 'lora' in model_name.lower() and model_base is None:
51
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52
+ if 'lora' in model_name.lower() and model_base is not None:
53
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
54
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
55
+ print('(LoRA) Loading LLaMA-VStream from base model...')
56
+ model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
57
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
58
+ if model.lm_head.weight.shape[0] != token_num:
59
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
60
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
61
+
62
+ print('(LoRA) Loading additional LLaMA-VStream weights...')
63
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
64
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
65
+ else:
66
+ # this is probably from HF Hub
67
+ from huggingface_hub import hf_hub_download
68
+ def load_from_hf(repo_id, filename, subfolder=None):
69
+ cache_file = hf_hub_download(
70
+ repo_id=repo_id,
71
+ filename=filename,
72
+ subfolder=subfolder)
73
+ return torch.load(cache_file, map_location='cpu')
74
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
75
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
76
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
77
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
78
+ model.load_state_dict(non_lora_trainables, strict=False)
79
+
80
+ from peft import PeftModel
81
+ print('Loading LoRA weights...')
82
+ model = PeftModel.from_pretrained(model, model_path)
83
+ print('Merging LoRA weights...')
84
+ model = model.merge_and_unload()
85
+ print('Model is loaded...')
86
+ elif model_base is not None:
87
+ # this may be mm projector only
88
+ print('Loading LLaMA-VStream from base model...')
89
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
90
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
91
+ model = VStreamLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
92
+
93
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
94
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
95
+ model.load_state_dict(mm_projector_weights, strict=False)
96
+ else:
97
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
98
+ model = VStreamLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
99
+ else:
100
+ # Load language model
101
+ if model_base is not None:
102
+ # PEFT model
103
+ from peft import PeftModel
104
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
105
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
106
+ print(f"Loading LoRA weights from {model_path}")
107
+ model = PeftModel.from_pretrained(model, model_path)
108
+ print(f"Merging weights")
109
+ model = model.merge_and_unload()
110
+ print('Convert to FP16...')
111
+ model.to(torch.float16)
112
+ else:
113
+ use_fast = False
114
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
115
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
116
+
117
+ image_processor = None
118
+
119
+ if 'vstream' in model_name.lower():
120
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
121
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
122
+ if mm_use_im_patch_token:
123
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
124
+ if mm_use_im_start_end:
125
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
126
+ model.resize_token_embeddings(len(tokenizer))
127
+
128
+ vision_tower = model.get_vision_tower()
129
+ if not vision_tower.is_loaded:
130
+ vision_tower.load_model()
131
+ vision_tower.to(device=device, dtype=torch.float16)
132
+ image_processor = vision_tower.image_processor
133
+
134
+ if hasattr(model.config, "max_sequence_length"):
135
+ context_len = model.config.max_sequence_length
136
+ else:
137
+ context_len = 2048
138
+
139
+ return tokenizer, model, image_processor, context_len
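+ # Minimal usage sketch (argument values are illustrative, not mandated by this file):
+ #   from flash_vstream.mm_utils import get_model_name_from_path
+ #   path = "path/to/flash-vstream-checkpoint"
+ #   tokenizer, model, image_processor, context_len = load_pretrained_model(
+ #       path, None, get_model_name_from_path(path), load_4bit=True)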
flash_vstream/model/compress_functions.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright 2024 Flash-VStream Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ def drop_feature(img_feature, video_max_frames, img_similarity=None):
21
+ T, P, D = img_feature.shape
22
+ indices = [[i] for i in range(T)]
23
+ T0 = video_max_frames
24
+ if T <= T0:
25
+ return img_feature, img_similarity, [indices]
26
+ cur_feature = img_feature[:T0] # [T0, P, D]
27
+ if img_similarity is not None:
28
+ cur_sim = img_similarity[:T0 - 1]
29
+ else:
30
+ cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
31
+ cur_indices = indices[:T0]
32
+ step_indices = [cur_indices]
33
+ for i in range(T0, T):
34
+ new_feature = img_feature[i]
35
+ new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
36
+ all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
37
+ all_indices = cur_indices + [[i]]
38
+ all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
39
+ idx = torch.argmax(all_sim)
40
+ if random.randint(0, 1) > 0:
41
+ idx = idx + 1
42
+ cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
43
+ if idx + 1 == T0 + 1:
44
+ cur_sim = all_sim[:T0 - 1]
45
+ cur_indices = all_indices[:-1]
46
+ elif idx == 0:
47
+ cur_sim = all_sim[1:]
48
+ cur_indices = all_indices[1:]
49
+ else:
50
+ cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
51
+ cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
52
+ cur_indices = all_indices[:idx] + all_indices[idx + 1:]
53
+ step_indices.append(cur_indices)
54
+ # print(f'Note: perform drop feature {img_feature.shape} to {cur_feature.shape}')
55
+ return cur_feature, cur_sim, step_indices
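+ # Behaviour sketch: for img_feature of shape [T, P, D] with T > video_max_frames, each new frame is
+ # appended and then one frame of the most similar adjacent pair is dropped (a coin flip picks which side),
+ # so the returned tensor always keeps video_max_frames frames; step_indices records which original
+ # frame indices survive after each step.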
56
+
57
+
58
+ def merge_feature(img_feature, video_max_frames, img_similarity=None):
59
+ T, P, D = img_feature.shape
60
+ indices = [[i] for i in range(T)]
61
+ T0 = video_max_frames
62
+ if T <= T0:
63
+ return img_feature, img_similarity, [indices]
64
+ cur_feature = img_feature[:T0] # [T0, P, D]
65
+ cur_indices = indices[:T0]
66
+ step_indices = [cur_indices]
67
+ if img_similarity is not None:
68
+ cur_sim = img_similarity[:T0 - 1]
69
+ else:
70
+ cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
71
+ for i in range(T0, T):
72
+ new_feature = img_feature[i]
73
+ new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
74
+ all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
75
+ all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
76
+ all_indices = cur_indices + [[i]]
77
+ idx = torch.argmax(all_sim)
78
+ all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0
79
+ all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1]
80
+ cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
81
+ cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
82
+ cur_indices = all_indices[:idx] + all_indices[idx + 1:]
83
+ if idx > 0:
84
+ cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
85
+ if idx + 1 < T0:
86
+ cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1), all_feature[idx + 2].view(-1), dim=0)
87
+ step_indices.append(cur_indices)
88
+ # print(f'Note: perform merge feature {img_feature.shape} to {cur_feature.shape}')
89
+ return cur_feature, cur_sim, step_indices
90
+
91
+
92
+ def kmeans_feature(img_feature, video_max_frames, img_similarity=None):
93
+ def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10):
94
+ indices = torch.randperm(X.size(0))[:num_clusters]
95
+ centroids = X[indices]
96
+ for i in range(max_iter):
97
+ if distance == 'euclidean':
98
+ dists = torch.cdist(X, centroids, p=2)
99
+ else:
100
+ raise NotImplementedError("Only Euclidean distance is supported yet")
101
+ labels = torch.argmin(dists, dim=1)
102
+ new_centroids = []
103
+ for j in range(num_clusters):
104
+ cluster_points = X[labels == j]
105
+ if len(cluster_points) > 0:
106
+ new_centroid = cluster_points.mean(0)
107
+ else: # fix nan centroids
108
+ new_centroid = X[random.randint(0, X.size(0) - 1)]
109
+ new_centroids.append(new_centroid)
110
+ new_centroids = torch.stack(new_centroids)
111
+ diff = torch.norm(centroids - new_centroids, dim=1).sum()
112
+ if diff < tol:
113
+ break
114
+ centroids = new_centroids
115
+ return centroids, labels, i
116
+ T, P, D = img_feature.shape
117
+ T0 = video_max_frames
118
+ if T <= T0:
119
+ return img_feature, img_similarity, [[[i] for i in range(T)]]
120
+ X = img_feature.view(T, -1) # [T, P, D]
121
+ centroids, labels, exit_step = kmeans_torch(X, T0)
122
+ reduced_feature = centroids.view(T0, P, D)
123
+ # print(f'Note: perform kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
124
+ step_indices = [[] for _ in range(T0)]
125
+ for i in range(T0):
126
+ step_indices[i] = [j for j in range(T) if labels[j] == i]
127
+ return reduced_feature, img_similarity, [step_indices]
128
+
129
+
130
+ def weighted_kmeans_feature(img_feature, video_max_frames, weights=None):
131
+ if weights is None:
132
+ weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device)
133
+ def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10):
134
+ indices = torch.randperm(X.size(0), device=X.device)[:num_clusters]
135
+ centroids = X[indices]
136
+ for i in range(max_iter):
137
+ if distance == 'euclidean':
138
+ dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(dim=2).sqrt()
139
+ else:
140
+ raise NotImplementedError("Only Euclidean distance is supported yet")
141
+ labels = torch.argmin(dists, dim=1)
142
+ weighted_sum = torch.zeros_like(centroids)
143
+ weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device)
144
+ for j in range(num_clusters):
145
+ cluster_mask = labels == j
146
+ weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0)
147
+ weights_sum[j] = torch.sum(weights[cluster_mask])
148
+ mask = weights_sum > 0
149
+ new_centroids = torch.zeros_like(weighted_sum)
150
+ new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None]
151
+ if mask.sum() < num_clusters: # fix nan centroids
152
+ new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)] for _ in range(num_clusters - mask.sum())])
153
+ diff = torch.norm(centroids - new_centroids, dim=1).sum()
154
+ if diff < tol:
155
+ break
156
+ centroids = new_centroids
157
+ return centroids, labels, weights_sum, i
158
+ T, P, D = img_feature.shape
159
+ T0 = video_max_frames
160
+ if T <= T0:
161
+ return img_feature, weights, [[[i] for i in range(T)]]
162
+ X = img_feature.view(T, -1) # [T, P, D]
163
+ centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights)
164
+ reduced_feature = centroids.view(T0, P, D)
165
+ # print(f'Note: perform weighted kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
166
+ step_indices = [[] for _ in range(T0)]
167
+ for i in range(T0):
168
+ step_indices[i] = [j for j in range(T) if labels[j] == i]
169
+ return reduced_feature, weights, [step_indices]
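+ # Shape sketch (illustrative): weighted_kmeans_feature(feat, 50) with feat of shape [T=120, P=1, D=1024]
+ # clusters the 120 frame features into 50 weighted centroids and returns a [50, 1, 1024] tensor,
+ # the per-centroid weights (cluster sizes when the input weights are all ones), and the frame-to-centroid assignment.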
170
+
171
+
172
+ def k_drop_feature(img_feature, video_max_frames, img_similarity=None):
173
+ T, P, D = img_feature.shape
174
+ indices = [[i] for i in range(T)]
175
+ T0 = video_max_frames
176
+ if T <= T0:
177
+ return img_feature, img_similarity, [indices]
178
+ cur_feature = img_feature[:T0] # [T0, P, D]
179
+ normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
180
+ cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
181
+ cur_sim.fill_diagonal_(-100.0)
182
+ cur_indices = indices[:T0]
183
+ step_indices = [cur_indices]
184
+ for i in range(T0, T):
185
+ # get new feature
186
+ new_feature = img_feature[i]
187
+ normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
188
+ new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
189
+ all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
190
+ normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
191
+ all_indices = cur_indices + [[i]]
192
+ # get new similarity
193
+ all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
194
+ all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1]
195
+ all_sim[-1, :-1] = new_sim.T
196
+ # choose compression position
197
+ idx = torch.argmax(all_sim)
198
+ left, right = idx // (T0 + 1), idx % (T0 + 1)
199
+ if random.randint(0, 1) > 0:
200
+ idx = left
201
+ else:
202
+ idx = right
203
+ assert all_sim[left, right] == torch.max(all_sim)
204
+ # get compressed feature and similarity
205
+ cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
206
+ normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]])
207
+ cur_indices = all_indices[:idx] + all_indices[idx + 1:]
208
+ cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0) # [T0, T0 + 1]
209
+ cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1) # [T0, T0]
210
+ step_indices.append(cur_indices)
211
+ # print(f'Note: perform k-drop feature {img_feature.shape} to {cur_feature.shape}')
212
+ return cur_feature, None, step_indices
213
+
214
+
215
+ def k_merge_feature(img_feature, video_max_frames, img_similarity=None):
216
+ T, P, D = img_feature.shape
217
+ indices = [[i] for i in range(T)]
218
+ T0 = video_max_frames
219
+ if T <= T0:
220
+ return img_feature, img_similarity, [indices]
221
+ cur_feature = img_feature[:T0] # [T0, P, D]
222
+ normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
223
+ cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
224
+ cur_sim.fill_diagonal_(-100.0)
225
+ cur_indices = indices[:T0]
226
+ step_indices = [cur_indices]
227
+ for i in range(T0, T):
228
+ # get new feature
229
+ new_feature = img_feature[i]
230
+ normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
231
+ new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
232
+ all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
233
+ normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
234
+ all_indices = cur_indices + [[i]]
235
+ # get new similarity
236
+ all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
237
+ all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0) # [T0 + 1, T0 + 1]
238
+ all_sim[-1, :-1] = new_sim.T
239
+ # choose compression position
240
+ idx = torch.argmax(all_sim)
241
+ left, right = idx // (T0 + 1), idx % (T0 + 1)
242
+ assert all_sim[left, right] == torch.max(all_sim)
243
+ # update feature
244
+ all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0
245
+ normed_all_features[right] = F.normalize(all_feature[right].view(1, P * D), p=2, dim=1)
246
+ all_indices[right] = all_indices[left] + all_indices[right]
247
+ # update similarity
248
+ new_sim = torch.mm(normed_all_features, normed_all_features[right:right+1].T) # [T0 + 1, 1]
249
+ all_sim[right, :] = new_sim.T
250
+ all_sim[:, right:right+1] = new_sim
251
+ all_sim[right, right] = -100.0
252
+ # get compressed feature and similarity
253
+ cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]])
254
+ normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]])
255
+ cur_indices = all_indices[:left] + all_indices[left + 1:]
256
+ cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0) # [T0, T0 + 1]
257
+ cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1) # [T0, T0]
258
+ step_indices.append(cur_indices)
259
+ # print(f'Note: perform k-merge feature {img_feature.shape} to {cur_feature.shape}')
260
+ return cur_feature, cur_sim, step_indices
261
+
262
+
263
+ def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2):
264
+ T, P, D = img_feature.shape
265
+ T0 = video_max_frames
266
+ if T <= T0:
267
+ return img_feature, None
268
+ cur_feature = img_feature[:T0] # [T0, P, D]
269
+ turing_memory = cur_feature.reshape(T0*P, D) # [T0*P, D]
270
+ for i in range(T0, T, T0):
271
+ j = min(i + T0, T)
272
+ new_feature = img_feature[i:j] # [P, D]
273
+ new_feature = new_feature.reshape(-1, D) # [n*P, D]
274
+ turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio) # [T0*P, n*P]
275
+ cur_feature = turing_memory.reshape(T0, P, D)
276
+ # print(f'Note: perform {attention_fn.__name__} feature {img_feature.shape} to {cur_feature.shape}')
277
+ return cur_feature, None
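+ # Behaviour sketch: attention_feature keeps a fixed-size "Turing" memory of T0*P tokens and folds every
+ # subsequent block of up to T0 frames into it through attention_fn (see VStreamMetaForCausalLM.attention),
+ # so an arbitrarily long stream is summarised into a constant [T0, P, D] tensor.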
flash_vstream/model/language_model/vstream_llama.py ADDED
@@ -0,0 +1,129 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # ------------------------------------------------------------------------
3
+ # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4
+ # Copyright 2023 Haotian Liu
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import List, Optional, Tuple, Union
22
+ from transformers import AutoConfig, AutoModelForCausalLM, \
23
+ LlamaConfig, LlamaModel, LlamaForCausalLM
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from flash_vstream.model.vstream_arch import VStreamMetaModel, VStreamMetaForCausalLM
26
+
27
+
28
+ class VStreamConfig(LlamaConfig):
29
+ model_type = "vstream"
30
+
31
+
32
+ class VStreamLlamaModel(VStreamMetaModel, LlamaModel):
33
+ config_class = VStreamConfig
34
+
35
+ def __init__(self, config: LlamaConfig):
36
+ super(VStreamLlamaModel, self).__init__(config)
37
+
38
+
39
+ class VStreamLlamaForCausalLM(VStreamMetaForCausalLM, LlamaForCausalLM):
40
+ config_class = VStreamConfig
41
+
42
+ def __init__(self, config):
43
+ super(VStreamLlamaForCausalLM, self).__init__(config)
44
+ self.model = VStreamLlamaModel(config)
45
+ self.pretraining_tp = config.pretraining_tp
46
+ self.vocab_size = config.vocab_size
47
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
48
+
49
+ # Initialize weights and apply final processing
50
+ self.post_init()
51
+
52
+ def get_model(self):
53
+ return self.model
54
+
55
+ def forward(
56
+ self,
57
+ input_ids: torch.LongTensor = None,
58
+ attention_mask: Optional[torch.Tensor] = None,
59
+ position_ids: Optional[torch.LongTensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ labels: Optional[torch.LongTensor] = None,
63
+ use_cache: Optional[bool] = True,
64
+ output_attentions: Optional[bool] = None,
65
+ output_hidden_states: Optional[bool] = None,
66
+ images: Optional[torch.FloatTensor] = None,
67
+ features: Optional[torch.FloatTensor] = None,
68
+ return_dict: Optional[bool] = None,
69
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
70
+ if inputs_embeds is None:
71
+ if self.use_video_streaming_mode:
72
+ (
73
+ input_ids,
74
+ position_ids,
75
+ attention_mask,
76
+ past_key_values,
77
+ inputs_embeds,
78
+ labels
79
+ ) = self.prepare_inputs_labels_for_multimodal_streaming(
80
+ input_ids,
81
+ position_ids,
82
+ attention_mask,
83
+ past_key_values,
84
+ labels,
85
+ )
86
+ else:
87
+ (
88
+ input_ids,
89
+ position_ids,
90
+ attention_mask,
91
+ past_key_values,
92
+ inputs_embeds,
93
+ labels
94
+ ) = self.prepare_inputs_labels_for_multimodal(
95
+ input_ids,
96
+ position_ids,
97
+ attention_mask,
98
+ past_key_values,
99
+ labels,
100
+ images,
101
+ features,
102
+ )
103
+ return super().forward(
104
+ input_ids=input_ids,
105
+ attention_mask=attention_mask,
106
+ position_ids=position_ids,
107
+ past_key_values=past_key_values,
108
+ inputs_embeds=inputs_embeds,
109
+ labels=labels,
110
+ use_cache=use_cache,
111
+ output_attentions=output_attentions,
112
+ output_hidden_states=output_hidden_states,
113
+ return_dict=return_dict
114
+ )
115
+
116
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
117
+ images = kwargs.pop("images", None)
118
+ features = kwargs.pop("features", None)
119
+ _inputs = super().prepare_inputs_for_generation(
120
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
121
+ )
122
+ if images is not None:
123
+ _inputs['images'] = images
124
+ if features is not None:
125
+ _inputs['features'] = features
126
+ return _inputs
127
+
128
+ AutoConfig.register("vstream", VStreamConfig)
129
+ AutoModelForCausalLM.register(VStreamConfig, VStreamLlamaForCausalLM)
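+ # With the two registrations above, the Auto classes resolve config.model_type == "vstream", e.g.
+ # (illustrative): AutoModelForCausalLM.from_pretrained(checkpoint_dir) returns a VStreamLlamaForCausalLM.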
flash_vstream/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,13 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import os
4
+ from .clip_encoder import CLIPVisionTower
5
+
6
+
7
+ def build_vision_tower(vision_tower_cfg, **kwargs):
8
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
9
+ is_absolute_path_exists = os.path.exists(vision_tower)
10
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
11
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
12
+
13
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
flash_vstream/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,80 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
7
+
8
+
9
+ class CLIPVisionTower(nn.Module):
10
+ def __init__(self, vision_tower, args, delay_load=False):
11
+ super().__init__()
12
+
13
+ self.is_loaded = False
14
+
15
+ self.vision_tower_name = vision_tower
16
+ self.select_layer = args.mm_vision_select_layer
17
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
18
+
19
+ if not delay_load:
20
+ self.load_model()
21
+ else:
22
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23
+
24
+ def load_model(self):
25
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
26
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
27
+ self.vision_tower.requires_grad_(False)
28
+
29
+ self.is_loaded = True
30
+
31
+ def feature_select(self, image_forward_outs):
32
+ image_features = image_forward_outs.hidden_states[self.select_layer]
33
+ if self.select_feature == 'patch':
34
+ image_features = image_features[:, 1:]
35
+ elif self.select_feature == 'cls_patch':
36
+ image_features = image_features
37
+ else:
38
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
39
+ return image_features
40
+
41
+ @torch.no_grad()
42
+ def forward(self, images):
43
+ if type(images) is list:
44
+ image_features = []
45
+ for image in images:
46
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
47
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
48
+ image_features.append(image_feature)
49
+ else:
50
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
51
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
52
+
53
+ return image_features
54
+
55
+ @property
56
+ def dummy_feature(self):
57
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
58
+
59
+ @property
60
+ def dtype(self):
61
+ return self.vision_tower.dtype
62
+
63
+ @property
64
+ def device(self):
65
+ return self.vision_tower.device
66
+
67
+ @property
68
+ def config(self):
69
+ if self.is_loaded:
70
+ return self.vision_tower.config
71
+ else:
72
+ return self.cfg_only
73
+
74
+ @property
75
+ def hidden_size(self):
76
+ return self.config.hidden_size
77
+
78
+ @property
79
+ def num_patches(self):
80
+ return (self.config.image_size // self.config.patch_size) ** 2
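+ # Example: for openai/clip-vit-large-patch14-336 (a common choice for LLaVA-style models),
+ # hidden_size is 1024 and num_patches is (336 // 14) ** 2 = 576 per frame.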
flash_vstream/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,51 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import re
6
+
7
+
8
+ class IdentityMap(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def forward(self, x, *args, **kwargs):
13
+ return x
14
+
15
+ @property
16
+ def config(self):
17
+ return {"mm_projector_type": 'identity'}
18
+
19
+
20
+ class SimpleResBlock(nn.Module):
21
+ def __init__(self, channels):
22
+ super().__init__()
23
+ self.pre_norm = nn.LayerNorm(channels)
24
+
25
+ self.proj = nn.Sequential(
26
+ nn.Linear(channels, channels),
27
+ nn.GELU(),
28
+ nn.Linear(channels, channels)
29
+ )
30
+ def forward(self, x):
31
+ x = self.pre_norm(x)
32
+ return x + self.proj(x)
33
+
34
+
35
+ def build_vision_projector(config, input_dim, delay_load=False, **kwargs):
36
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
37
+
38
+ if projector_type == 'linear':
39
+ return nn.Linear(input_dim, config.hidden_size)
40
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
41
+ if mlp_gelu_match:
42
+ mlp_depth = int(mlp_gelu_match.group(1))
43
+ modules = [nn.Linear(input_dim, config.hidden_size)]
44
+ for _ in range(1, mlp_depth):
45
+ modules.append(nn.GELU())
46
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47
+ return nn.Sequential(*modules)
48
+ if projector_type == 'identity':
49
+ return IdentityMap()
50
+
51
+ raise ValueError(f'Unknown projector type: {projector_type}')
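+ # Example (illustrative): with config.mm_projector_type = 'mlp2x_gelu', config.hidden_size = 4096 and
+ # input_dim = 1024, build_vision_projector returns nn.Sequential(Linear(1024, 4096), GELU(), Linear(4096, 4096)).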
flash_vstream/model/vstream_arch.py ADDED
@@ -0,0 +1,742 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # ------------------------------------------------------------------------
3
+ # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4
+ # Copyright 2023 Haotian Liu
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import time
19
+ import math
20
+ import logging
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch.multiprocessing import Lock, Manager
25
+
26
+ from abc import ABC, abstractmethod
27
+ from flash_vstream.model.multimodal_encoder.builder import build_vision_tower
28
+ from flash_vstream.model.multimodal_projector.builder import build_vision_projector
29
+ from flash_vstream.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
30
+
31
+ from flash_vstream.model.compress_functions import drop_feature, merge_feature, kmeans_feature, weighted_kmeans_feature, k_drop_feature, k_merge_feature, attention_feature
32
+
33
+
34
+ class NeuralTuringMachine(nn.Module):
35
+ def __init__(self, input_dim=1024, output_dim=1024, attention_dropout=0.1):
36
+ super(NeuralTuringMachine, self).__init__()
37
+ self.input_dim = input_dim
38
+ self.output_dim = output_dim
39
+ self.q_proj = nn.Linear(input_dim, output_dim)
40
+ self.k_proj = nn.Linear(input_dim, output_dim)
41
+ self.v_proj = nn.Linear(input_dim, output_dim)
42
+ self.dropout = nn.Dropout(attention_dropout)
43
+ self.out_proj = nn.Linear(output_dim, input_dim)
44
+ self.out_dropout = nn.Dropout(attention_dropout)
45
+ self.out_ln = nn.LayerNorm(input_dim, eps=1e-12)
46
+
47
+ def get_weight(self, x, y):
48
+ query = self.q_proj(x)
49
+ key = self.k_proj(y)
50
+ scores = torch.matmul(query, key.transpose(0, 1)) / math.sqrt(self.output_dim)
51
+ weight = F.softmax(scores, dim=-1)
52
+ return weight
53
+
54
+ def forward(self, x, y):
55
+ query = self.q_proj(x)
56
+ key = self.k_proj(y)
57
+ scores = torch.matmul(query, key.transpose(0, 1)) / math.sqrt(self.output_dim)
58
+ weight = F.softmax(scores, dim=-1)
59
+ attn = self.dropout(weight)
60
+ value = self.v_proj(y)
61
+ output = torch.matmul(attn, value)
62
+ output = self.out_proj(output)
63
+ output = self.out_dropout(output)
64
+ output = self.out_ln(output.unsqueeze(0)).squeeze(0)
65
+ return output
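+ # Shape note: x is the current memory [T1, input_dim] and y the incoming features [T2, input_dim];
+ # forward() returns an updated [T1, input_dim] memory (cross-attention from memory queries to new keys/values),
+ # while get_weight() only returns the [T1, T2] attention matrix used by the lighter update in
+ # VStreamMetaForCausalLM.attention below.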
66
+
67
+
68
+ class VStreamMetaModel:
69
+
70
+ def __init__(self, config):
71
+ super(VStreamMetaModel, self).__init__(config)
72
+
73
+ self.mm_input_dim = config.mm_hidden_size
74
+ if getattr(config, 'mm_use_4_vision_tokens', False):
75
+ self.mm_input_dim = self.mm_input_dim * 4
76
+
77
+ if hasattr(config, "mm_vision_tower"):
78
+ self.vision_tower = build_vision_tower(config, delay_load=True)
79
+ self.mm_projector = build_vision_projector(config, self.mm_input_dim)
80
+
81
+ compress_Turing_hidden_dim = getattr(self.config, "compress_Turing_hidden_dim", 32)
82
+ self.attention_model = NeuralTuringMachine(self.mm_input_dim, compress_Turing_hidden_dim)
83
+
84
+ def get_vision_tower(self):
85
+ vision_tower = getattr(self, 'vision_tower', None)
86
+ if type(vision_tower) is list:
87
+ vision_tower = vision_tower[0]
88
+ return vision_tower
89
+
90
+ def initialize_vision_modules(self, model_args, fsdp=None):
91
+ vision_tower = model_args.vision_tower
92
+ mm_vision_select_layer = model_args.mm_vision_select_layer
93
+ mm_vision_select_feature = model_args.mm_vision_select_feature
94
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
95
+
96
+ self.config.mm_vision_tower = vision_tower
97
+
98
+ if self.get_vision_tower() is None:
99
+ vision_tower = build_vision_tower(model_args)
100
+
101
+ if fsdp is not None and len(fsdp) > 0:
102
+ self.vision_tower = [vision_tower]
103
+ else:
104
+ self.vision_tower = vision_tower
105
+ else:
106
+ if fsdp is not None and len(fsdp) > 0:
107
+ vision_tower = self.vision_tower[0]
108
+ else:
109
+ vision_tower = self.vision_tower
110
+ vision_tower.load_model()
111
+
112
+ self.config.use_mm_proj = True
113
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
114
+ self.config.mm_hidden_size = vision_tower.hidden_size
115
+ self.config.mm_vision_select_layer = mm_vision_select_layer
116
+ self.config.mm_vision_select_feature = mm_vision_select_feature
117
+
118
+ self.config.compress_type = getattr(model_args, "compress_type", None)
119
+ self.config.compress_size = getattr(model_args, "compress_size", 1)
120
+ self.config.compress_long_memory_size = getattr(model_args, "compress_long_memory_size", 1)
121
+ self.config.compress_Turing_memory_size = getattr(model_args, "compress_Turing_memory_size", 1)
122
+ self.config.compress_Turing_update_ratio = getattr(model_args, "compress_Turing_update_ratio", 0.2)
123
+ self.config.video_max_frames = getattr(model_args, "video_max_frames", 50)
124
+ self.config.video_long_memory_length = getattr(model_args, "video_long_memory_length", 10)
125
+ self.config.video_Turing_memory_length = getattr(model_args, "video_Turing_memory_length", 10)
126
+ self.config.video_short_memory_length = getattr(model_args, "video_short_memory_length", 10)
127
+ self.config.video_current_memory_length = getattr(model_args, "video_current_memory_length", 1)
128
+ self.config.video_sample_type = getattr(model_args, "video_sample_type", "center")
129
+
130
+ if getattr(self, 'mm_projector', None) is None:
131
+ self.mm_projector = build_vision_projector(self.config)
132
+ else:
133
+ # In case it is frozen by LoRA
134
+ for p in self.mm_projector.parameters():
135
+ p.requires_grad = True
136
+
137
+ if pretrain_mm_mlp_adapter is not None:
138
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
139
+ def get_w(weights, keyword):
140
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
141
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
142
+
143
+ class VStreamMetaForCausalLM(ABC):
144
+
145
+ def __init__(self, config):
146
+ super(VStreamMetaForCausalLM, self).__init__(config)
147
+ # support video streaming mode
148
+ self.use_video_streaming_mode = False
149
+ self.video_embedding_memory = None # set to torch.multiprocessing.Manager.list() when launching
150
+ self.video_embedding_mem_lock = Lock()
151
+
152
+ @abstractmethod
153
+ def get_model(self):
154
+ pass
155
+
156
+ def get_vision_tower(self):
157
+ return self.get_model().get_vision_tower()
158
+
159
+ def encode_images(self, images):
160
+ image_features = self.get_model().get_vision_tower()(images)
161
+ return image_features
162
+
163
+ def reshape_2x2_image_features(self, image_features):
164
+ B, P, D = image_features.shape
165
+ patch_size = round(math.sqrt(P))
166
+ assert patch_size % 2 == 0, "Patch size must be divisible by 2."
167
+ image_features = image_features.reshape(B, patch_size, patch_size, D)
168
+ image_features_2x2 = image_features.reshape(B, patch_size // 2, 2, patch_size // 2, 2, D)
169
+ image_features_2x2 = image_features_2x2.permute(0, 1, 3, 2, 4, 5)
170
+ image_features_2x2 = image_features_2x2.reshape(B, patch_size // 2, patch_size // 2, 4 * D) # concat 2x2 neighbor patches
171
+ image_features = image_features_2x2.reshape(B, (patch_size // 2) ** 2, 4 * D)
172
+ return image_features
173
+
174
+ def attention(self, turing_memory, new_feature, update_ratio=0.2):
175
+ T1, D1 = turing_memory.shape
176
+ T2, D2 = new_feature.shape
177
+ assert D1 == D2, f"dimmension not match, {D1} != {D2}"
178
+ model = self.get_model().attention_model
179
+ weight = model.get_weight(turing_memory, new_feature)
180
+ weight = weight * update_ratio # [T1, T2]
181
+ decay = weight.sum(dim=1, keepdim=True) # [T0*P, 1], 表示当前NTM memory和新来的feat的相似度
182
+ turing_memory = turing_memory * (1 - decay) + torch.mm(weight, new_feature)
183
+ return turing_memory
184
+
185
+ def attention2(self, turing_memory, new_feature, update_ratio=0.2): # deprecated
186
+ T1, D1 = turing_memory.shape
187
+ T2, D2 = new_feature.shape
188
+ assert D1 == D2, f"dimmension not match, {D1} != {D2}"
189
+ model = self.get_model().attention_model
190
+ turing_memory = model.forward(turing_memory, new_feature)
191
+ return turing_memory
192
+
193
+ def compress_spatial_features(self, image_features, compress_size=1):
194
+ compress_type = getattr(self.config, "compress_type", None)
195
+ patch_size = round(math.sqrt(image_features.shape[1]))
196
+ assert patch_size * patch_size == image_features.shape[1], f"For ViT feature map, {patch_size}*{patch_size}={patch_size**2} != {image_features.shape[1]}"
197
+ if patch_size == compress_size:
198
+ return image_features
199
+ elif compress_type is not None:
200
+ if 'mean' in self.config.compress_type:
201
+ # TODO: currently use 1 token per frame (or image), direct poolt
202
+ if compress_size == 1:
203
+ image_features = image_features.mean(dim=1, keepdim=True)
204
+ else:
205
+ image_features = image_features.view(-1, patch_size, patch_size, image_features.shape[-1])
206
+ image_features = image_features.permute(0, 3, 1, 2) # [B*T, D, P, P]
207
+ pooled_features = F.avg_pool2d(image_features, (patch_size // compress_size, patch_size // compress_size))
208
+ pooled_features = pooled_features.permute(0, 2, 3, 1) # [B*T, P, P, D]
209
+ image_features = pooled_features.view(-1, compress_size * compress_size, pooled_features.shape[-1])
210
+ else:
211
+ raise NotImplementedError(f"`compress_type` {self.config.compress_type} is not supported yet.")
212
+ return image_features
213
+
214
+ def compress_temporal_features(self, image_features):
215
+ video_long_memory_length = getattr(self.config, "video_long_memory_length", 10)
216
+ video_Turing_memory_length = getattr(self.config, "video_Turing_memory_length", 10)
217
+ video_short_memory_length = getattr(self.config, "video_short_memory_length", 10) # not used
218
+ video_current_memory_length = getattr(self.config, "video_current_memory_length", 1)
219
+ compress_long_memory_size = getattr(self.config, "compress_long_memory_size", 1)
220
+ compress_Turing_memory_size = getattr(self.config, "compress_Turing_memory_size", 1)
221
+ compress_Turing_update_ratio = getattr(self.config, "compress_Turing_update_ratio", 0.2)
222
+ compress_fn_dic = {
223
+ 'drop': drop_feature,
224
+ 'merge': merge_feature,
225
+ 'kmeans': kmeans_feature,
226
+ 'weighted_kmeans': weighted_kmeans_feature,
227
+ 'kdrop': k_drop_feature,
228
+ 'kmerge': k_merge_feature,
229
+ 'attention': attention_feature,
230
+ }
231
+ compress_type = self.config.video_sample_type
232
+ if compress_type in compress_fn_dic:
233
+ compress_fn = compress_fn_dic[compress_type]
234
+ else:
235
+ raise NotImplementedError(f'max_length = {self.config.video_max_frames},'
236
+ f'while video_sample_type = {compress_type} is not supported yet.')
237
+ new_image_features = []
238
+ step_indices = []
239
+ step_features = []
240
+ for img_feature in image_features: # [T, P*P, D]
241
+ cur_start = min(video_current_memory_length, img_feature.shape[0])
242
+ ### Calc Spatial Memory
243
+ if cur_start == 0:
244
+ cur_memory = img_feature[:0]
245
+ long_memory = img_feature
246
+ Turing_memory = img_feature
247
+ else:
248
+ cur_memory = img_feature[-cur_start:] # [C, P*P, D]
249
+ long_memory = img_feature[:-cur_start] # [L, P*P, D]
250
+ Turing_memory = img_feature[:-cur_start] # [L, P*P, D]
251
+ if compress_long_memory_size * compress_long_memory_size != long_memory.shape[1]:
252
+ long_memory = self.compress_spatial_features(long_memory, compress_long_memory_size) # [L, P'*P', D]
253
+ if compress_Turing_memory_size * compress_Turing_memory_size != Turing_memory.shape[1]:
254
+ Turing_memory = self.compress_spatial_features(Turing_memory, compress_Turing_memory_size) # [L, P'*P', D]
255
+ ### Calc Temporal Memory
256
+ if video_long_memory_length == 0 or long_memory.shape[0] == 0:
257
+ long_memory_compreesed = long_memory[:0]
258
+ else:
259
+ long_memory_compreesed, weight, step_long_indices = compress_fn(long_memory, video_long_memory_length) # [L_long, P'*P', D], [L_long]
260
+ ### Calc Retrieved Memory
261
+ sorted_indices = torch.argsort(weight, descending=True) # [L_long]
262
+ key_centroids = long_memory[sorted_indices] # [L_long, P'*P', D]
263
+ key_length = 3
264
+ if key_centroids.shape[0] > key_length:
265
+ key_centroids = key_centroids[:key_length]
266
+ dists = ((long_memory.unsqueeze(1) - key_centroids.unsqueeze(0)) ** 2).sum(dim=3).sum(dim=2).sqrt() # [L_long, k_L]
267
+ min_indices = torch.argmin(dists, dim=0) # [k_L]
268
+ key_memory = img_feature[min_indices]
269
+ cur_memory = torch.cat([key_memory, cur_memory], dim=0)
270
+ ### Calc Abstract Memory
271
+ if video_Turing_memory_length == 0 or Turing_memory.shape[0] == 0:
272
+ Turing_memory_compreesed = Turing_memory[:0]
273
+ else:
274
+ Turing_memory_compreesed, _ = attention_feature(Turing_memory, video_Turing_memory_length, self.attention, update_ratio=compress_Turing_update_ratio)
275
+ memory_feature = torch.cat([Turing_memory_compreesed.flatten(0, 1), long_memory_compreesed.flatten(0, 1), cur_memory.flatten(0, 1)], dim=0)
276
+ new_image_features.append(memory_feature)
277
+ return new_image_features
278
+
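`compress_temporal_features` splits each frame sequence into current, long-term, and Turing (abstract) memories and shrinks the older parts with a pluggable policy from `compress_fn_dic`. Those policies (`drop_feature`, `weighted_kmeans_feature`, ...) are defined elsewhere in the repo; as a rough mental model, a 'drop'-style policy simply subsamples frames uniformly down to the memory budget. A hedged sketch:

```python
import torch

def drop_feature_sketch(memory: torch.Tensor, max_length: int) -> torch.Tensor:
    """Keep at most `max_length` frames of a [T, P, D] memory by uniform subsampling.

    Illustrative stand-in for a 'drop' policy; the repo's weighted k-means
    variants additionally return per-slot weights used for key-frame retrieval.
    """
    t = memory.shape[0]
    if t <= max_length:
        return memory
    idx = torch.linspace(0, t - 1, steps=max_length).round().long()
    return memory[idx]

mem = torch.randn(57, 16, 1024)                # 57 frames, 4x4 tokens each
print(drop_feature_sketch(mem, 10).shape)      # torch.Size([10, 16, 1024])
```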
279
+ def cat_proj(self, all_features): # concatenate features and project them together
280
+ feature_split_size = [x.shape[0] for x in all_features]
281
+ feature_embed = torch.cat(all_features, dim=0)
282
+ feature_proj = self.get_model().mm_projector(feature_embed)
283
+ feature_proj = torch.split(feature_proj, feature_split_size, dim=0)
284
+ return feature_proj
285
+
286
+ def prepare_inputs_labels_for_multimodal(
287
+ self,
288
+ input_ids,
289
+ position_ids,
290
+ attention_mask,
291
+ past_key_values,
292
+ labels,
293
+ images,
294
+ features
295
+ ):
296
+ vision_tower = self.get_vision_tower()
297
+ if vision_tower is None or (images is None and features is None) or input_ids.shape[1] == 1:
298
+ if past_key_values is not None and vision_tower is not None and ((images is not None) or (features is not None)) and input_ids.shape[1] == 1:
299
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
300
+ if target_shape - attention_mask.shape[1] >= 0:
301
+ attention_mask = torch.cat((attention_mask, torch.ones(
302
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
303
+ dtype=attention_mask.dtype,
304
+ device=attention_mask.device
305
+ )), dim=1)
306
+ elif target_shape - attention_mask.shape[1] < 0:
307
+ attention_mask = attention_mask[:, :target_shape]
308
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
309
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
310
+
311
+ if (features is not None) or (type(images) is list) or (images.ndim == 5):
312
+ compress_size = getattr(self.config, "compress_size", 1)
313
+ if images is not None:
314
+ images = [image if len(image.shape) == 4 else image.unsqueeze(0) for image in images] # [B, T, C, H, W]
315
+ concat_images = torch.cat([image for image in images], dim=0) # [B*T, C, H, W]
316
+ image_features = self.encode_images(concat_images) # [B*T, P, D]
317
+ if getattr(self.config, 'mm_use_4_vision_tokens', False):
318
+ image_features = self.reshape_2x2_image_features(image_features) # [B*T, P/4, 4*D]
319
+ image_features = self.compress_spatial_features(image_features, compress_size) # [B*T, P', D]
320
+ split_sizes = [image.shape[0] for image in images]
321
+ image_features = torch.split(image_features, split_sizes, dim=0) # [B, T, P, D]
322
+ else:
323
+ image_features = [feat if len(feat.shape) == 3 else feat.unsqueeze(0) for feat in features]
324
+ origin_img_features = image_features
325
+ if getattr(self.config, 'mm_use_4_vision_tokens', False):
326
+ image_features = [self.reshape_2x2_image_features(img_feature) for img_feature in image_features] # [B*T, P/4, 4*D]
327
+ image_features = [self.compress_spatial_features(image_feature, compress_size) for image_feature in image_features] # [B*T, P', D]
328
+ # perform memory consolidation
329
+ image_features = self.compress_temporal_features(image_features) # [B, TP, D]
330
+ image_features = [x.to(self.device) for x in image_features] # [B, TP, D]
331
+ image_features = self.cat_proj(image_features)
332
+ else:
333
+ image_features = self.encode_images(images).to(self.device) # [B, 576, 2048]
334
+ if getattr(self.config, 'mm_use_4_vision_tokens', False):
335
+ image_features = self.reshape_2x2_image_features(image_features) # [B*T, P/4, 4*D]
336
+ image_features = self.get_model().mm_projector(image_features)
337
+
338
+ # TODO: image start / end is not implemented here to support pretraining.
339
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
340
+ raise NotImplementedError
341
+
342
+ _labels = labels
343
+ _position_ids = position_ids
344
+ _attention_mask = attention_mask
345
+ if attention_mask is None:
346
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
347
+ else:
348
+ attention_mask = attention_mask.bool()
349
+ if position_ids is None:
350
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
351
+ if labels is None:
352
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
353
+
354
+ # remove the padding using attention_mask -- TODO: double check
355
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
356
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
357
+ new_input_embeds = []
358
+ new_labels = []
359
+ cur_image_idx = 0
360
+ for batch_idx, cur_input_ids in enumerate(input_ids):
361
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
362
+ if num_images == 0:
363
+ cur_image_features = image_features[cur_image_idx]
364
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
365
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
366
+ new_input_embeds.append(cur_input_embeds)
367
+ new_labels.append(labels[batch_idx])
368
+ cur_image_idx += 1
369
+ continue
370
+
371
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] # only input first image_token
372
+ cur_input_ids_noim = []
373
+ cur_labels = labels[batch_idx]
374
+ cur_labels_noim = []
375
+ for i in range(len(image_token_indices) - 1):
376
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
377
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
378
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
379
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
380
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
381
+ cur_new_input_embeds = []
382
+ cur_new_labels = []
383
+
384
+ for i in range(num_images + 1):
385
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
386
+ cur_new_labels.append(cur_labels_noim[i])
387
+ if i < num_images:
388
+ cur_image_features = image_features[cur_image_idx]
389
+ cur_image_idx += 1
390
+ cur_new_input_embeds.append(cur_image_features)
391
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
392
+
393
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
394
+ cur_new_labels = torch.cat(cur_new_labels)
395
+
396
+ new_input_embeds.append(cur_new_input_embeds)
397
+ new_labels.append(cur_new_labels)
398
+ assert cur_image_idx == batch_idx + 1
399
+
400
+ # Truncate sequences to max length as image embeddings can make the sequence longer
401
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
402
+ if tokenizer_model_max_length is not None:
403
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
404
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
405
+
406
+ # Combine them
407
+ max_len = max(x.shape[0] for x in new_input_embeds)
408
+ batch_size = len(new_input_embeds)
409
+
410
+ new_input_embeds_padded = []
411
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
412
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
413
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
414
+
415
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
416
+ cur_len = cur_new_embed.shape[0]
417
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
418
+ new_input_embeds_padded.append(torch.cat((
419
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
420
+ cur_new_embed
421
+ ), dim=0))
422
+ if cur_len > 0:
423
+ new_labels_padded[i, -cur_len:] = cur_new_labels
424
+ attention_mask[i, -cur_len:] = True
425
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
426
+ else:
427
+ new_input_embeds_padded.append(torch.cat((
428
+ cur_new_embed,
429
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
430
+ ), dim=0))
431
+ if cur_len > 0:
432
+ new_labels_padded[i, :cur_len] = cur_new_labels
433
+ attention_mask[i, :cur_len] = True
434
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
435
+
436
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
437
+
438
+ if _labels is None:
439
+ new_labels = None
440
+ else:
441
+ new_labels = new_labels_padded
442
+
443
+ if _attention_mask is None:
444
+ attention_mask = None
445
+ else:
446
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
447
+
448
+ if _position_ids is None:
449
+ position_ids = None
450
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
451
+
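The core of `prepare_inputs_labels_for_multimodal` is splicing visual embeddings into the text embedding sequence wherever the `IMAGE_TOKEN_INDEX` sentinel appears. A stripped-down sketch of that interleaving, without labels, padding, or batching (function and variable names are illustrative):

```python
import torch

IMAGE_TOKEN_INDEX = -200  # sentinel id; the real value lives in flash_vstream.constants

def splice_image_features(input_ids, embed_fn, image_feats):
    """Split ids around image sentinels, embed the text spans, and interleave
    the visual embeddings in order."""
    spans, cur, n_img = [], [], 0
    for tok in input_ids.tolist():
        if tok == IMAGE_TOKEN_INDEX:
            spans.append(cur)
            cur = []
            n_img += 1
        else:
            cur.append(tok)
    spans.append(cur)
    assert n_img == len(image_feats), "one feature tensor per image token"
    pieces = []
    for i, span in enumerate(spans):
        if span:
            pieces.append(embed_fn(torch.tensor(span)))
        if i < n_img:
            pieces.append(image_feats[i])
    return torch.cat(pieces, dim=0)

embed = torch.nn.Embedding(100, 8)
ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3])
out = splice_image_features(ids, embed, [torch.randn(5, 8)])
print(out.shape)  # torch.Size([8, 8]): 3 text tokens + 5 visual tokens
```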
452
+ def prepare_inputs_labels_for_multimodal_streaming( # Asynchronous encoding with a SemLock, only for videos, batch_size=1
453
+ self,
454
+ input_ids,
455
+ position_ids,
456
+ attention_mask,
457
+ past_key_values,
458
+ labels
459
+ ):
460
+ assert self.use_video_streaming_mode
461
+ logger = logging.getLogger(__name__)
462
+ vision_tower = self.get_vision_tower()
463
+ if vision_tower is None or input_ids.shape[1] == 1:
464
+ if past_key_values is not None and vision_tower is not None and input_ids.shape[1] == 1:
465
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
466
+ if target_shape - attention_mask.shape[1] >= 0:
467
+ attention_mask = torch.cat((attention_mask, torch.ones(
468
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
469
+ dtype=attention_mask.dtype,
470
+ device=attention_mask.device
471
+ )), dim=1)
472
+ elif target_shape - attention_mask.shape[1] < 0:
473
+ attention_mask = attention_mask[:, :target_shape]
474
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
475
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
476
+ # Have some tries to avoid deadlock
477
+ attempt_times = 0
478
+ while attempt_times < 300:
479
+ try:
480
+ with self.video_embedding_mem_lock:
481
+ cur_memory, long_memory_compreesed, Turing_memory_compreesed, _ = self.video_embedding_memory
482
+ logger.info(f'Read cur_memory={cur_memory.shape} {cur_memory.dtype}, long_memory_compreesed={long_memory_compreesed.shape} {long_memory_compreesed.dtype}, Turing_memory_compreesed={Turing_memory_compreesed.shape} {Turing_memory_compreesed.dtype}')
483
+ image_feature = torch.cat([Turing_memory_compreesed.flatten(0, 1), long_memory_compreesed.flatten(0, 1), cur_memory.flatten(0, 1)], dim=0)
484
+ image_features = [image_feature.to(self.device)]
485
+ break
486
+
487
+ except Exception as e:
488
+ logger.error(f'Attempt:{attempt_times} Failed to get video features, Error: {e}')
489
+ image_features = []
490
+ time.sleep(0.1)
491
+ attempt_times += 1
492
+
493
+ image_features = [x.to(self.device) for x in image_features] # [B, TP, D]
494
+ image_features = self.cat_proj(image_features)
495
+
496
+ # TODO: image start / end is not implemented here to support pretraining.
497
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
498
+ raise NotImplementedError
499
+
500
+ _labels = labels
501
+ _position_ids = position_ids
502
+ _attention_mask = attention_mask
503
+ if attention_mask is None:
504
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
505
+ else:
506
+ attention_mask = attention_mask.bool()
507
+ if position_ids is None:
508
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
509
+ if labels is None:
510
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
511
+
512
+ # remove the padding using attention_mask -- TODO: double check
513
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
514
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
515
+
516
+ new_input_embeds = []
517
+ new_labels = []
518
+ cur_image_idx = 0
519
+ for batch_idx, cur_input_ids in enumerate(input_ids):
520
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
521
+ if num_images == 0:
522
+ cur_image_features = image_features[cur_image_idx]
523
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
524
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
525
+ new_input_embeds.append(cur_input_embeds)
526
+ new_labels.append(labels[batch_idx])
527
+ cur_image_idx += 1
528
+ continue
529
+
530
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] # only input first image_token
531
+ cur_input_ids_noim = []
532
+ cur_labels = labels[batch_idx]
533
+ cur_labels_noim = []
534
+ for i in range(len(image_token_indices) - 1):
535
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
536
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
537
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
538
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
539
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
540
+ cur_new_input_embeds = []
541
+ cur_new_labels = []
542
+
543
+ for i in range(num_images + 1):
544
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
545
+ cur_new_labels.append(cur_labels_noim[i])
546
+ if i < num_images:
547
+ cur_image_features = image_features[cur_image_idx]
548
+ cur_image_idx += 1
549
+ cur_new_input_embeds.append(cur_image_features)
550
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
551
+
552
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
553
+ cur_new_labels = torch.cat(cur_new_labels)
554
+
555
+ new_input_embeds.append(cur_new_input_embeds)
556
+ new_labels.append(cur_new_labels)
557
+ assert cur_image_idx == batch_idx + 1
558
+
559
+ # Truncate sequences to max length as image embeddings can make the sequence longer
560
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
561
+ if tokenizer_model_max_length is not None:
562
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
563
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
564
+
565
+ # Combine them
566
+ max_len = max(x.shape[0] for x in new_input_embeds)
567
+ batch_size = len(new_input_embeds)
568
+
569
+ new_input_embeds_padded = []
570
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
571
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
572
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
573
+
574
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
575
+ cur_len = cur_new_embed.shape[0]
576
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
577
+ new_input_embeds_padded.append(torch.cat((
578
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
579
+ cur_new_embed
580
+ ), dim=0))
581
+ if cur_len > 0:
582
+ new_labels_padded[i, -cur_len:] = cur_new_labels
583
+ attention_mask[i, -cur_len:] = True
584
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
585
+ else:
586
+ new_input_embeds_padded.append(torch.cat((
587
+ cur_new_embed,
588
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
589
+ ), dim=0))
590
+ if cur_len > 0:
591
+ new_labels_padded[i, :cur_len] = cur_new_labels
592
+ attention_mask[i, :cur_len] = True
593
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
594
+
595
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
596
+
597
+ if _labels is None:
598
+ new_labels = None
599
+ else:
600
+ new_labels = new_labels_padded
601
+
602
+ if _attention_mask is None:
603
+ attention_mask = None
604
+ else:
605
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
606
+
607
+ if _position_ids is None:
608
+ position_ids = None
609
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
610
+
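In streaming mode the LLM process never sees raw frames; it reads the latest memory snapshot that the embedding process publishes into a `multiprocessing.Manager` list, guarded by a lock and a bounded retry loop (the `attempt_times < 300` loop above). A hedged sketch of just that read pattern:

```python
import time
import logging
from multiprocessing import Manager

def read_memory_snapshot(shared_list, lock, max_attempts=300, wait_s=0.1):
    """Return (cur, long, turing) memories once the writer has populated the
    shared list; retry briefly if it has not been written yet."""
    for attempt in range(max_attempts):
        try:
            with lock:
                cur, long_mem, turing_mem, _ = list(shared_list)
            return cur, long_mem, turing_mem
        except (ValueError, IndexError) as exc:  # list still empty / partially written
            logging.warning("attempt %d: memory not ready (%s)", attempt, exc)
            time.sleep(wait_s)
    raise RuntimeError("video memory was never populated")

if __name__ == "__main__":
    manager = Manager()
    shared = manager.list(["cur", "long", "turing", "frame_buffer"])
    lock = manager.Lock()
    print(read_memory_snapshot(shared, lock))  # ('cur', 'long', 'turing')
```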
611
+ def embed_video_streaming( # Asynchronous encoding with a SemLock, only for videos, batch_size=1
612
+ self,
613
+ images
614
+ ):
615
+ assert self.use_video_streaming_mode
616
+ logger = logging.getLogger(__name__)
617
+
618
+ compress_size = getattr(self.config, "compress_size", 1)
619
+ video_long_memory_length = getattr(self.config, "video_long_memory_length", 10)
620
+ video_Turing_memory_length = getattr(self.config, "video_Turing_memory_length", 10)
621
+ video_short_memory_length = getattr(self.config, "video_short_memory_length", 10) # not used
622
+ video_current_memory_length = getattr(self.config, "video_current_memory_length", 1)
623
+ compress_long_memory_size = getattr(self.config, "compress_long_memory_size", 1)
624
+ compress_Turing_memory_size = getattr(self.config, "compress_Turing_memory_size", 1)
625
+ compress_Turing_update_ratio = getattr(self.config, "compress_Turing_update_ratio", 0.2)
626
+ compress_fn_dic = {
627
+ 'drop': drop_feature,
628
+ 'merge': merge_feature,
629
+ 'kmeans': kmeans_feature,
630
+ 'weighted_kmeans': weighted_kmeans_feature,
631
+ 'kdrop': k_drop_feature,
632
+ 'kmerge': k_merge_feature,
633
+ 'uni_kmerge': k_merge_feature,
634
+ 'both_kmerge': k_merge_feature,
635
+ 'split_kmerge': k_merge_feature,
636
+ 'attention': attention_feature,
637
+ }
638
+
639
+ if type(images) is list or images.ndim == 5:
640
+ assert len(images) == 1
641
+ images = [image if len(image.shape) == 4 else image.unsqueeze(0) for image in images] # [B, T, C, H, W]
642
+ concat_images = torch.cat([image for image in images], dim=0) # [B*T, C, H, W]
643
+ image_features = self.encode_images(concat_images) # [B*T, P, D]
644
+ image_features = self.compress_spatial_features(image_features, compress_size) # [B*T, P', D]
645
+ split_sizes = [image.shape[0] for image in images]
646
+ image_features = torch.split(image_features, split_sizes, dim=0) # [B, T, P, D]
647
+ else:
648
+ raise NotImplementedError('Should input video frames, not a single image')
649
+ image_feature = image_features[0].detach().to(torch.float16).to(self.device) # [T, P, D]
650
+ img_feature_buffer = image_feature.cpu()
651
+
652
+ cur_start = min(video_current_memory_length, image_feature.shape[0])
653
+ if cur_start == 0:
654
+ cur_memory = image_feature[:0]
655
+ else:
656
+ cur_memory = image_feature[-cur_start:] # [L_c, P*P, D]
657
+ long_memory = image_feature
658
+ Turing_memory = image_feature
659
+ if compress_long_memory_size * compress_long_memory_size != long_memory.shape[1]:
660
+ long_memory = self.compress_spatial_features(long_memory, compress_long_memory_size) # [L_l, P'*P', D]
661
+ if compress_Turing_memory_size * compress_Turing_memory_size != Turing_memory.shape[1]:
662
+ Turing_memory = self.compress_spatial_features(Turing_memory, compress_Turing_memory_size) # [L_t, P'*P', D]
663
+ compress_type = self.config.video_sample_type
664
+ if compress_type in compress_fn_dic:
665
+ compress_fn = compress_fn_dic[compress_type]
666
+ else:
667
+ raise NotImplementedError(f'max_length = {self.config.video_max_frames},'
668
+ f'while video_sample_type = {compress_type} is not supported yet.')
669
+ long_memory_compreesed = long_memory
670
+ Turing_memory_compreesed = Turing_memory
671
+ # Read old memory from shared memory, do not need an I/O lock
672
+ if self.video_embedding_memory is not None and len(self.video_embedding_memory) > 0:
673
+ old_cur_memory, old_long_memory_compreesed, old_Turing_memory_compreesed, old_img_feature_buffer = self.video_embedding_memory
674
+ old_long_memory_compreesed = old_long_memory_compreesed.to(self.device)
675
+ old_Turing_memory_compreesed = old_Turing_memory_compreesed.to(self.device)
676
+ img_feature_buffer = torch.cat([old_img_feature_buffer, image_feature.cpu()], dim=0)
677
+ assert isinstance(old_long_memory_compreesed, torch.Tensor) and old_long_memory_compreesed.shape[1:] == long_memory_compreesed.shape[1:]
678
+ long_memory = torch.cat((old_long_memory_compreesed, long_memory_compreesed), dim=0)
679
+ long_memory_compreesed, weight, step_long_indices = compress_fn(long_memory, video_long_memory_length)
680
+ # Retrieve key frames
681
+ sorted_indices = torch.argsort(weight, descending=True) # [L_long]
682
+ key_centroids = long_memory[sorted_indices] # [L_long, P'*P', D]
683
+ key_length = 3
684
+ if key_centroids.shape[0] > key_length:
685
+ key_centroids = key_centroids[:key_length]
686
+ dists = ((long_memory.unsqueeze(1) - key_centroids.unsqueeze(0)) ** 2).sum(dim=3).sum(dim=2).sqrt() # [L_long, k_L]
687
+ min_indices = torch.argmin(dists, dim=0) # [k_L]
688
+ key_memory = img_feature_buffer[min_indices.cpu()].to(self.device)
689
+ cur_memory = torch.cat([key_memory, cur_memory], dim=0)
690
+ Turing_memory = torch.cat((old_Turing_memory_compreesed, Turing_memory_compreesed), dim=0)
691
+ Turing_memory_compreesed, _ = attention_feature(Turing_memory, video_Turing_memory_length, self.attention, update_ratio=compress_Turing_update_ratio)
692
+ # Write to shared memory, need an I/O lock
693
+ with self.video_embedding_mem_lock:
694
+ self.video_embedding_memory[:] = [cur_memory.cpu(), long_memory_compreesed.cpu(), Turing_memory_compreesed.cpu(), img_feature_buffer] # Only change content
695
+ logger.info(f'Write cur_memory={cur_memory.shape} {cur_memory.dtype}, long_memory_compreesed={long_memory_compreesed.shape} {long_memory_compreesed.dtype}, Turing_memory_compreesed={Turing_memory_compreesed.shape} {Turing_memory_compreesed.dtype}')
696
+
697
+ return []
698
+
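Both `compress_temporal_features` and `embed_video_streaming` include a retrieval step: the frames closest (in L2 distance) to the top-weighted memory slots are pulled back at full resolution and prepended to the current memory. A compact sketch of that step (tensor names and shapes are illustrative):

```python
import torch

def retrieve_key_frames(long_memory, weights, frame_bank, k=3):
    """Pick the frames whose compressed features sit closest to the k most
    heavily weighted memory slots.

    long_memory: [L, P, D] compressed per-frame features
    weights:     [L] importance weights returned by the compressor
    frame_bank:  [L, P0, D] uncompressed per-frame features to retrieve from
    """
    order = torch.argsort(weights, descending=True)
    centroids = long_memory[order[:k]]                               # [k, P, D]
    dists = ((long_memory.unsqueeze(1) - centroids.unsqueeze(0))     # [L, k, P, D]
             ** 2).sum(dim=(2, 3)).sqrt()                            # [L, k]
    nearest = torch.argmin(dists, dim=0)                             # [k]
    return frame_bank[nearest]

mem = torch.randn(20, 4, 64)      # 20 compressed slots
w = torch.rand(20)                # importance weights from the compressor
bank = torch.randn(20, 16, 64)    # full-resolution per-frame features
print(retrieve_key_frames(mem, w, bank).shape)  # torch.Size([3, 16, 64])
```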
699
+
700
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
701
+ if model_args.mm_use_im_patch_token:
702
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
703
+ self.resize_token_embeddings(len(tokenizer))
704
+
705
+ if model_args.mm_use_im_start_end:
706
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
707
+ self.resize_token_embeddings(len(tokenizer))
708
+
709
+ if num_new_tokens > 0:
710
+ input_embeddings = self.get_input_embeddings().weight.data
711
+ output_embeddings = self.get_output_embeddings().weight.data
712
+
713
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
714
+ dim=0, keepdim=True)
715
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
716
+ dim=0, keepdim=True)
717
+
718
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
719
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
720
+
721
+ if model_args.tune_mm_mlp_adapter:
722
+ for p in self.get_input_embeddings().parameters():
723
+ p.requires_grad = True
724
+ for p in self.get_output_embeddings().parameters():
725
+ p.requires_grad = False
726
+
727
+ if model_args.pretrain_mm_mlp_adapter:
728
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
729
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
730
+ assert num_new_tokens == 2
731
+ if input_embeddings.shape == embed_tokens_weight.shape:
732
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
733
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
734
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
735
+ else:
736
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
737
+ elif model_args.mm_use_im_patch_token:
738
+ if model_args.tune_mm_mlp_adapter:
739
+ for p in self.get_input_embeddings().parameters():
740
+ p.requires_grad = False
741
+ for p in self.get_output_embeddings().parameters():
742
+ p.requires_grad = False
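`initialize_vision_tokenizer` follows the common LLaVA recipe: newly added special tokens get embeddings initialised to the mean of the existing rows, so their statistics stay close to trained tokens. A minimal reproduction with a small off-the-shelf model (the checkpoint and the `<im_start>`/`<im_end>` strings are only stand-ins for the repo's constants):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

num_new = tok.add_tokens(["<im_start>", "<im_end>"], special_tokens=True)
model.resize_token_embeddings(len(tok))

# Mean-initialise the new rows of both the input and output embeddings.
for emb in (model.get_input_embeddings().weight.data,
            model.get_output_embeddings().weight.data):
    emb[-num_new:] = emb[:-num_new].mean(dim=0, keepdim=True)
```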
flash_vstream/serve/cli_video_stream.py ADDED
@@ -0,0 +1,351 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # Based on https://github.com/haotian-liu/LLaVA.
3
+ """
4
+ This file demonstrates an implementation of a multiprocess Real-time Long Video Understanding System, with a multiprocess logging module.
5
+ main process: CLI server I/O, LLM inference
6
+ process-1: logger listener
7
+ process-2: frame generator,
8
+ process-3: frame memory manager
9
+ Author: Haoji Zhang, Haotian Liu
10
+ (This code is based on https://github.com/haotian-liu/LLaVA)
11
+ """
12
+ import argparse
13
+ import requests
14
+ import logging
+ import logging.handlers
15
+ import torch
16
+ import numpy as np
17
+ import time
18
+ import os
19
+
20
+ from flash_vstream.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
21
+ from flash_vstream.conversation import conv_templates, SeparatorStyle
22
+ from flash_vstream.model.builder import load_pretrained_model
23
+ from flash_vstream.utils import disable_torch_init
24
+ from flash_vstream.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
25
+
26
+ from torch.multiprocessing import Process, Queue, Manager
27
+ from transformers import TextStreamer
28
+ from decord import VideoReader
29
+ from datetime import datetime
30
+ from PIL import Image
31
+ from io import BytesIO
32
+
33
+ class _Metric:
34
+ def __init__(self):
35
+ self._latest_value = None
36
+ self._sum = 0.0
37
+ self._max = 0.0
38
+ self._count = 0
39
+
40
+ @property
41
+ def val(self):
42
+ return self._latest_value
43
+
44
+ @property
45
+ def max(self):
46
+ return self._max
47
+
48
+ @property
49
+ def avg(self):
50
+ if self._count == 0:
51
+ return float('nan')
52
+ return self._sum / self._count
53
+
54
+ def add(self, value):
55
+ self._latest_value = value
56
+ self._sum += value
57
+ self._count += 1
58
+ if value > self._max:
59
+ self._max = value
60
+
61
+ def __str__(self):
62
+ latest_formatted = f"{self.val:.6f}" if self.val is not None else "None"
63
+ average_formatted = f"{self.avg:.6f}"
64
+ max_formatted = f"{self.max:.6f}"
65
+ return f"{latest_formatted} ({average_formatted}, {max_formatted})"
66
+
67
+
68
+ class MetricMeter:
69
+ def __init__(self):
70
+ self._metrics = {}
71
+
72
+ def add(self, key, value):
73
+ if key not in self._metrics:
74
+ self._metrics[key] = _Metric()
75
+ self._metrics[key].add(value)
76
+
77
+ def val(self, key):
78
+ metric = self._metrics.get(key)
79
+ if metric is None or metric.val is None:
80
+ raise ValueError(f"No values have been added for key '{key}'.")
81
+ return metric.val
82
+
83
+ def avg(self, key):
84
+ metric = self._metrics.get(key)
85
+ if metric is None:
86
+ raise ValueError(f"No values have been added for key '{key}'.")
87
+ return metric.avg
88
+
89
+ def max(self, key):
90
+ metric = self._metrics.get(key)
91
+ if metric is None:
92
+ raise ValueError(f"No values have been added for key '{key}'.")
93
+ return metric.max
94
+
95
+ def __getitem__(self, key):
96
+ metric = self._metrics.get(key)
97
+ if metric is None:
98
+ raise KeyError(f"The key '{key}' does not exist.")
99
+ return str(metric)
100
+
101
+ def load_image(image_file):
102
+ if image_file.startswith('http://') or image_file.startswith('https://'):
103
+ response = requests.get(image_file)
104
+ image = Image.open(BytesIO(response.content)).convert('RGB')
105
+ else:
106
+ image = Image.open(image_file).convert('RGB')
107
+ return image
108
+
109
+ def listener(queue, filename):
110
+ ############## Start sub process-1: Listener #############
111
+ import sys, traceback
112
+ root = logging.getLogger()
113
+ root.setLevel(logging.DEBUG)
114
+ # h = logging.StreamHandler(sys.stdout)
115
+ h = logging.FileHandler(filename)
116
+ f = logging.Formatter('%(asctime)s %(processName)-10s %(name)s %(levelname)-8s %(message)s')
117
+ h.setFormatter(f)
118
+ root.addHandler(h)
119
+ while True:
120
+ try:
121
+ record = queue.get()
122
+ if record is None: # None is a signal to finish
123
+ break
124
+ logger = logging.getLogger(record.name)
125
+ logger.handle(record) # No level or filter logic applied - just do it!
126
+ except Exception:
127
+ import sys, traceback
128
+ print('Whoops! Problem:', file=sys.stderr)
129
+ traceback.print_exc(file=sys.stderr)
130
+
131
+ def worker_configurer(queue):
132
+ h = logging.handlers.QueueHandler(queue) # Just the one handler needed
133
+ root = logging.getLogger()
134
+ root.addHandler(h)
135
+ root.setLevel(logging.DEBUG)
136
+
137
+ def video_stream_simulator(video_file, frame_queue, log_queue, video_fps=1.0, play_speed=1.0):
138
+ ############## Start sub process-2: Simulator #############
139
+ worker_configurer(log_queue)
140
+ logger = logging.getLogger(__name__)
141
+ logger.setLevel(logging.DEBUG)
142
+
143
+ vr = VideoReader(video_file)
144
+ sample_fps = round(vr.get_avg_fps() / video_fps)
145
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
146
+ video = vr.get_batch(frame_idx).asnumpy()
147
+ video = np.repeat(video, 6, axis=0)
148
+ length = video.shape[0]
149
+ sleep_time = 1 / video_fps / play_speed
150
+ time_meter = MetricMeter()
151
+ logger.info(f'Simulator Process: start, length = {length}')
152
+ try:
153
+ for start in range(0, length):
154
+ start_time = time.perf_counter()
155
+ end = min(start + 1, length)
156
+ video_clip = video[start:end]
157
+ frame_queue.put(video_clip)
158
+ if start > 0:
159
+ time_meter.add('real_sleep', start_time - last_start)
160
+ logger.info(f'Simulator: write {end - start} frames,\t{start} to {end},\treal_sleep={time_meter["real_sleep"]}')
161
+ if end < length:
162
+ time.sleep(sleep_time)
163
+ last_start = start_time
164
+ frame_queue.put(None)
165
+ except Exception as e:
166
+ print(f'Simulator Exception: {e}')
167
+ time.sleep(0.1)
168
+ logger.info(f'Simulator Process: end')
169
+
170
+ def frame_memory_manager(model, image_processor, frame_queue, log_queue):
171
+ ############## Start sub process-3: Memory Manager #############
172
+ worker_configurer(log_queue)
173
+ logger = logging.getLogger(__name__)
174
+ logger.setLevel(logging.DEBUG)
175
+
176
+ time_meter = MetricMeter()
177
+ logger.info(f'MemManager Process: start')
178
+ frame_cnt = 0
179
+ while True:
180
+ try:
181
+ video_clip = frame_queue.get()
182
+ start_time = time.perf_counter()
183
+ if video_clip is None:
184
+ logger.info(f'MemManager: got None, end of stream')
185
+ break
186
+ logger.info(f'MemManager: get {video_clip.shape[0]} frames from queue')
187
+ image = image_processor.preprocess(video_clip, return_tensors='pt')['pixel_values']
188
+ image = image.unsqueeze(0)
189
+ image_tensor = image.to(model.device, dtype=torch.float16)
190
+ # time_2 = time.perf_counter()
191
+ logger.info(f'MemManager: Start embedding')
192
+ with torch.inference_mode():
193
+ model.embed_video_streaming(image_tensor)
194
+ logger.info(f'MemManager: End embedding')
195
+ end_time = time.perf_counter()
196
+ if frame_cnt > 0:
197
+ time_meter.add('memory_latency', end_time - start_time)
198
+ logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={time_meter["memory_latency"]}')
199
+ else:
200
+ logger.info(f'MemManager: embedded {video_clip.shape[0]} frames,\tidx={frame_cnt},\tmemory_latency={end_time - start_time:.6f}, not logged')
201
+ frame_cnt += video_clip.shape[0]
202
+ except Exception as e:
203
+ print(f'MemManager Exception: {e}')
204
+ time.sleep(0.1)
205
+ logger.info(f'MemManager Process: end')
206
+
207
+ def main(args):
208
+ # torch.multiprocessing.log_to_stderr(logging.DEBUG)
209
+ torch.multiprocessing.set_start_method('spawn', force=True)
210
+ disable_torch_init()
211
+
212
+ log_queue = Queue()
213
+ frame_queue = Queue(maxsize=10)
214
+ processes = []
215
+
216
+ ############## Start listener process #############
217
+ p1 = Process(target=listener, args=(log_queue, args.log_file))
218
+ processes.append(p1)
219
+ p1.start()
220
+
221
+ ############## Start main process #############
222
+ worker_configurer(log_queue)
223
+ logger = logging.getLogger(__name__)
224
+
225
+ model_name = get_model_name_from_path(args.model_path)
226
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
227
+
228
+ logger.info(f'Using conv_mode={args.conv_mode}')
229
+
230
+ conv = conv_templates[args.conv_mode].copy()
231
+ if "mpt" in model_name.lower():
232
+ roles = ('user', 'assistant')
233
+ else:
234
+ roles = conv.roles
235
+
236
+ with Manager() as manager:
237
+ image_tensor = None
238
+ model.use_video_streaming_mode = True
239
+ model.video_embedding_memory = manager.list()
240
+ if args.video_max_frames is not None:
241
+ model.config.video_max_frames = args.video_max_frames
242
+ logger.info(f'Important: set model.config.video_max_frames = {model.config.video_max_frames}')
243
+
244
+ logger.info(f'Important: set video_fps = {args.video_fps}')
245
+ logger.info(f'Important: set play_speed = {args.play_speed}')
246
+
247
+ ############## Start simulator process #############
248
+ p2 = Process(target=video_stream_simulator,
249
+ args=(args.video_file, frame_queue, log_queue, args.video_fps, args.play_speed))
250
+ processes.append(p2)
251
+ p2.start()
252
+
253
+ ############## Start memory manager process #############
254
+ p3 = Process(target=frame_memory_manager,
255
+ args=(model, image_processor, frame_queue, log_queue))
256
+ processes.append(p3)
257
+ p3.start()
258
+
259
+ # start QA server
260
+ start_time = datetime.now()
261
+ time_meter = MetricMeter()
262
+ conv_cnt = 0
263
+ while True:
264
+ time.sleep(5)
265
+ try:
266
+ # inp = input(f"{roles[0]}: ")
267
+ inp = "what is in the video?"
268
+ except EOFError:
269
+ inp = ""
270
+ if not inp:
271
+ print("exit...")
272
+ break
273
+
274
+ # Get the current time
275
+ now = datetime.now()
276
+ conv_start_time = time.perf_counter()
277
+ # Format the current time as a string
278
+ current_time = now.strftime("%H:%M:%S")
279
+ duration = now.timestamp() - start_time.timestamp()
280
+
281
+ # Print the current time
282
+ print("\nCurrent Time:", current_time, "Run for:", duration)
283
+ print(f"{roles[0]}: {inp}", end="\n")
284
+ print(f"{roles[1]}: ", end="")
285
+ # every conversation is a new conversation
286
+ conv = conv_templates[args.conv_mode].copy()
287
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
288
+ conv.append_message(conv.roles[0], inp)
289
+
290
+ conv.append_message(conv.roles[1], None)
291
+ prompt = conv.get_prompt()
292
+
293
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
294
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
295
+ keywords = [stop_str]
296
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
297
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
298
+
299
+ llm_start_time = time.perf_counter()
300
+ with torch.inference_mode():
301
+ output_ids = model.generate(
302
+ input_ids,
303
+ images=image_tensor,
304
+ do_sample=True if args.temperature > 0 else False,
305
+ temperature=args.temperature,
306
+ max_new_tokens=args.max_new_tokens,
307
+ streamer=streamer,
308
+ use_cache=True,
309
+ stopping_criteria=[stopping_criteria]
310
+ )
311
+ llm_end_time = time.perf_counter()
312
+
313
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
314
+ conv.messages[-1][-1] = outputs
315
+ conv_end_time = time.perf_counter()
316
+ if conv_cnt > 0:
317
+ time_meter.add('conv_latency', conv_end_time - conv_start_time)
318
+ time_meter.add('llm_latency', llm_end_time - llm_start_time)
319
+ time_meter.add('real_sleep', conv_start_time - last_conv_start_time)
320
+ logger.info(f'CliServer: idx={conv_cnt},\treal_sleep={time_meter["real_sleep"]},\tconv_latency={time_meter["conv_latency"]},\tllm_latency={time_meter["llm_latency"]}')
321
+ else:
322
+ logger.info(f'CliServer: idx={conv_cnt},\tconv_latency={conv_end_time - conv_start_time},\tllm_latency={llm_end_time - llm_start_time}')
323
+ conv_cnt += 1
324
+ last_conv_start_time = conv_start_time
325
+
326
+ for p in processes:
327
+ p.terminate()
328
+ print("All processes finished.")
329
+
330
+
331
+ if __name__ == "__main__":
332
+ parser = argparse.ArgumentParser()
333
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
334
+ parser.add_argument("--model-base", type=str, default=None)
335
+ parser.add_argument("--image-file", type=str, default=None)
336
+ parser.add_argument("--video-file", type=str, default=None)
337
+ parser.add_argument("--device", type=str, default="cuda")
338
+ parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
339
+ parser.add_argument("--temperature", type=float, default=0.2)
340
+ parser.add_argument("--max-new-tokens", type=int, default=512)
341
+ parser.add_argument("--load-8bit", action="store_true")
342
+ parser.add_argument("--load-4bit", action="store_true")
343
+ parser.add_argument("--debug", action="store_true")
344
+
345
+ parser.add_argument("--log-file", type=str, default="tmp_cli.log")
346
+ parser.add_argument("--use_1process", action="store_true")
347
+ parser.add_argument("--video_max_frames", type=int, default=None)
348
+ parser.add_argument("--video_fps", type=float, default=1.0)
349
+ parser.add_argument("--play_speed", type=float, default=1.0)
350
+ args = parser.parse_args()
351
+ main(args)
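`cli_video_stream.py` wires three processes together: a log listener, a frame simulator feeding a bounded `Queue`, and a memory manager that embeds clips and publishes them through the shared `Manager` list, while the main process runs the QA loop. A bare-bones skeleton of that producer/consumer layout (names are illustrative, not the script's API):

```python
import time
from multiprocessing import Manager
from torch.multiprocessing import Process, Queue

def producer(frame_queue):
    for i in range(5):
        frame_queue.put(f"clip-{i}")    # the real script puts numpy frame clips
        time.sleep(0.1)
    frame_queue.put(None)               # sentinel: end of stream

def consumer(frame_queue, shared_memory, lock):
    while True:
        clip = frame_queue.get()
        if clip is None:
            break
        with lock:                      # same I/O-lock idea as embed_video_streaming
            shared_memory[:] = [clip]   # replace contents in place

if __name__ == "__main__":
    with Manager() as manager:
        shared, lock = manager.list(), manager.Lock()
        queue = Queue(maxsize=10)
        procs = [Process(target=producer, args=(queue,)),
                 Process(target=consumer, args=(queue, shared, lock))]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print("last published clip:", list(shared))
```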
flash_vstream/serve/demo.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ from ..constants import *
3
+ from ..conversation import conv_templates, SeparatorStyle
4
+ from ..model.builder import load_pretrained_model
5
+ from ..utils import disable_torch_init
6
+ from ..mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
7
+ from PIL import Image
8
+ import os
9
+ from decord import VideoReader, cpu
10
+ import numpy as np
11
+
12
+
13
+ class Chat:
14
+ def __init__(self, model_path, conv_mode="simple", load_8bit=False, load_4bit=False):
15
+ disable_torch_init()
16
+ self.tokenizer, self.model, self.image_processor, context_len = load_pretrained_model(model_path, None, model_name="ChatUniVi", load_8bit=load_8bit, load_4bit=load_4bit)
17
+
18
+ mm_use_im_start_end = getattr(self.model.config, "mm_use_im_start_end", False)
19
+ mm_use_im_patch_token = getattr(self.model.config, "mm_use_im_patch_token", True)
20
+ if mm_use_im_patch_token:
21
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
22
+ if mm_use_im_start_end:
23
+ self.tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
24
+ self.model.resize_token_embeddings(len(self.tokenizer))
25
+
26
+ vision_tower = self.model.get_vision_tower()
27
+ if not vision_tower.is_loaded:
28
+ vision_tower.load_model()
29
+
30
+ self.image_processor = vision_tower.image_processor
31
+ self.conv_mode = conv_mode
32
+ print(self.model)
33
+
34
+ def get_prompt(self, qs, state):
35
+ state.append_message(state.roles[0], qs)
36
+ state.append_message(state.roles[1], None)
37
+ return state
38
+
39
+ def _get_rawvideo_dec(self, video_path, image_processor, max_frames=MAX_IMAGE_LENGTH, image_resolution=224,
40
+ video_framerate=1, s=None, e=None):
41
+ if s is None:
42
+ start_time, end_time = None, None
43
+ else:
44
+ start_time = int(s)
45
+ end_time = int(e)
46
+ start_time = start_time if start_time >= 0. else 0.
47
+ end_time = end_time if end_time >= 0. else 0.
48
+ if start_time > end_time:
49
+ start_time, end_time = end_time, start_time
50
+ elif start_time == end_time:
51
+ end_time = start_time + 1
52
+
53
+ if os.path.exists(video_path):
54
+ vreader = VideoReader(video_path, ctx=cpu(0))
55
+ else:
56
+ print(video_path)
57
+ raise FileNotFoundError
58
+
59
+ fps = vreader.get_avg_fps()
60
+ f_start = 0 if start_time is None else int(start_time * fps)
61
+ f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
62
+ num_frames = f_end - f_start + 1
63
+ if num_frames > 0:
64
+ sample_fps = int(video_framerate)
65
+ t_stride = int(round(float(fps) / sample_fps))
66
+
67
+ all_pos = list(range(f_start, f_end + 1, t_stride))
68
+ if len(all_pos) > max_frames:
69
+ sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
70
+ else:
71
+ sample_pos = all_pos
72
+
73
+ patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
74
+ return patch_images
75
+
76
+ @torch.inference_mode()
77
+ def generate(self, images_tensor: list, prompt: str, first_run: bool, state):
78
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
79
+
80
+ state = self.get_prompt(prompt, state)
81
+ prompt = state.get_prompt()
82
+ print(prompt)
83
+
84
+ images_tensor = torch.stack(images_tensor, dim=0)
85
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
86
+
87
+ temperature = 0.2
88
+ max_new_tokens = 1024
89
+
90
+ stop_str = conv_templates[self.conv_mode].copy().sep if conv_templates[self.conv_mode].copy().sep_style != SeparatorStyle.TWO else \
91
+ conv_templates[self.conv_mode].copy().sep2
92
+ keywords = [stop_str]
93
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
94
+
95
+ with torch.inference_mode():
96
+ output_ids = model.generate(
97
+ input_ids,
98
+ images=images_tensor,
99
+ do_sample=True,
100
+ temperature=temperature,
101
+ num_beams=1,
102
+ max_new_tokens=max_new_tokens,
103
+ use_cache=True,
104
+ stopping_criteria=[stopping_criteria])
105
+
106
+ input_token_len = input_ids.shape[1]
107
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
108
+ if n_diff_input_output > 0:
109
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
110
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
111
+ outputs = outputs.strip()
112
+ if outputs.endswith(stop_str):
113
+ outputs = outputs[:-len(stop_str)]
114
+ outputs = outputs.strip()
115
+
116
+ print('response', outputs)
117
+ return outputs, state
118
+
119
+
120
+
121
+ title_markdown = ("""
122
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
123
+ <a href="https://github.com/PKU-YuanGroup/Chat-UniVi" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
124
+ <img src="https://z1.ax1x.com/2023/11/22/pidlXh4.jpg" alt="Chat-UniVi🚀" style="max-width: 120px; height: auto;">
125
+ </a>
126
+ <div>
127
+ <h1 >Chat-UniVi: Unified Visual Representation Empowers Large Language Models with Image and Video Understanding</h1>
128
+ <h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5>
129
+ </div>
130
+ </div>
131
+ <div align="center">
132
+ <div style="display:flex; gap: 0.25rem;" align="center">
133
+ <a href='https://github.com/PKU-YuanGroup/Chat-UniVi'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
134
+ <a href="https://arxiv.org/pdf/2311.08046.pdf"><img src="https://img.shields.io/badge/Arxiv-2311.08046-red"></a>
135
+ <a href='https://github.com/PKU-YuanGroup/Chat-UniVi/stargazers'><img src='https://img.shields.io/github/stars/PKU-YuanGroup/Chat-UniVi.svg?style=social'></a>
136
+ </div>
137
+ </div>
138
+ """)
139
+
140
+ block_css = """
141
+ #buttons button {
142
+ min-width: min(120px,100%);
143
+ }
144
+ """
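`_get_rawvideo_dec` decodes frames with decord at a fixed frame rate and, if the clip is long, uniformly subsamples down to `max_frames`. The same logic in isolation (argument names are illustrative):

```python
import numpy as np
from decord import VideoReader, cpu
from PIL import Image

def sample_frames(video_path, target_fps=1, max_frames=64):
    """Decode roughly `target_fps` frames per second, capped at `max_frames`."""
    vr = VideoReader(video_path, ctx=cpu(0))
    stride = max(1, round(vr.get_avg_fps() / target_fps))
    positions = list(range(0, len(vr), stride))
    if len(positions) > max_frames:
        positions = [positions[i] for i in
                     np.linspace(0, len(positions) - 1, max_frames, dtype=int)]
    return [Image.fromarray(f) for f in vr.get_batch(positions).asnumpy()]
```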
flash_vstream/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,117 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ from typing import Optional, Tuple
4
+ import warnings
5
+
6
+ import torch
7
+
8
+ import transformers
9
+ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
10
+
11
+ try:
12
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
13
+ except ImportError:
14
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
15
+ from flash_attn.bert_padding import unpad_input, pad_input
16
+
17
+
18
+ def forward(
19
+ self,
20
+ hidden_states: torch.Tensor,
21
+ attention_mask: Optional[torch.Tensor] = None,
22
+ position_ids: Optional[torch.Tensor] = None,
23
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
24
+ output_attentions: bool = False,
25
+ use_cache: bool = False,
26
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
27
+ if output_attentions:
28
+ warnings.warn(
29
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
30
+ )
31
+
32
+ bsz, q_len, _ = hidden_states.size()
33
+
34
+ query_states = (
35
+ self.q_proj(hidden_states)
36
+ .view(bsz, q_len, self.num_heads, self.head_dim)
37
+ .transpose(1, 2)
38
+ )
39
+ key_states = (
40
+ self.k_proj(hidden_states)
41
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
42
+ .transpose(1, 2)
43
+ )
44
+ value_states = (
45
+ self.v_proj(hidden_states)
46
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
47
+ .transpose(1, 2)
48
+ ) # shape: (b, num_heads, s, head_dim)
49
+
50
+ kv_seq_len = key_states.shape[-2]
51
+ if past_key_value is not None:
52
+ kv_seq_len += past_key_value[0].shape[-2]
53
+
54
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55
+ query_states, key_states = apply_rotary_pos_emb(
56
+ query_states, key_states, cos, sin, position_ids
57
+ )
58
+
59
+ if past_key_value is not None:
60
+ # reuse k, v
61
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
62
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
63
+
64
+ past_key_value = (key_states, value_states) if use_cache else None
65
+
66
+ # repeat k/v heads if n_kv_heads < n_heads
67
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
68
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
69
+
70
+ # Transform the data into the format required by flash attention
71
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
72
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
73
+ key_padding_mask = attention_mask
74
+
75
+ if key_padding_mask is None:
76
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
77
+ cu_q_lens = torch.arange(
78
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
79
+ )
80
+ max_s = q_len
81
+ output = flash_attn_unpadded_qkvpacked_func(
82
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
83
+ )
84
+ output = output.view(bsz, q_len, -1)
85
+ else:
86
+ qkv = qkv.reshape(bsz, q_len, -1)
87
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
88
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
89
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
90
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
91
+ )
92
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
93
+ output = pad_input(output_unpad, indices, bsz, q_len)
94
+
95
+ return self.o_proj(output), None, past_key_value
96
+
97
+
98
+ # Disable the transformation of the attention mask in LlamaModel as the flash attention
99
+ # requires the attention mask to be the same as the key_padding_mask
100
+ def _prepare_decoder_attention_mask(
101
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
102
+ ):
103
+ # [bsz, seq_len]
104
+ return attention_mask
105
+
106
+
107
+ def replace_llama_attn_with_flash_attn():
108
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
109
+ if cuda_major < 8:
110
+ warnings.warn(
111
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
112
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
113
+ )
114
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
115
+ _prepare_decoder_attention_mask
116
+ )
117
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
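The patch works by rebinding `LlamaAttention.forward` at the class level, so every attention module picks up the flash-attention path; by convention it is applied early, before any forward pass. A usage sketch (the checkpoint name is only an example):

```python
from flash_vstream.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from transformers import AutoModelForCausalLM

replace_llama_attn_with_flash_attn()
model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5", torch_dtype="auto")
```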
flash_vstream/train/llama_xformers_attn_monkey_patch.py ADDED
@@ -0,0 +1,131 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ """
4
+ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
5
+ """
6
+
7
+ import logging
8
+ import math
9
+ from typing import Optional, Tuple
10
+
11
+ import torch
12
+ import transformers.models.llama.modeling_llama
13
+ from torch import nn
14
+
15
+ try:
16
+ import xformers.ops
17
+ except ImportError:
18
+ logging.error("xformers not found! Please install it before trying to use it.")
19
+
20
+
21
+ def replace_llama_attn_with_xformers_attn():
22
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
23
+
24
+
25
+ def xformers_forward(
26
+ self,
27
+ hidden_states: torch.Tensor,
28
+ attention_mask: Optional[torch.Tensor] = None,
29
+ position_ids: Optional[torch.LongTensor] = None,
30
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
31
+ output_attentions: bool = False,
32
+ use_cache: bool = False,
33
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
34
+ # pylint: disable=duplicate-code
35
+ bsz, q_len, _ = hidden_states.size()
36
+
37
+ query_states = (
38
+ self.q_proj(hidden_states)
39
+ .view(bsz, q_len, self.num_heads, self.head_dim)
40
+ .transpose(1, 2)
41
+ )
42
+ key_states = (
43
+ self.k_proj(hidden_states)
44
+ .view(bsz, q_len, self.num_heads, self.head_dim)
45
+ .transpose(1, 2)
46
+ )
47
+ value_states = (
48
+ self.v_proj(hidden_states)
49
+ .view(bsz, q_len, self.num_heads, self.head_dim)
50
+ .transpose(1, 2)
51
+ )
52
+
53
+ kv_seq_len = key_states.shape[-2]
54
+ if past_key_value is not None:
55
+ kv_seq_len += past_key_value[0].shape[-2]
56
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
57
+ (
58
+ query_states,
59
+ key_states,
60
+ ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
61
+ query_states, key_states, cos, sin, position_ids
62
+ )
63
+ # [bsz, nh, t, hd]
64
+
65
+ if past_key_value is not None:
66
+ # reuse k, v, self_attention
67
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
68
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
69
+
70
+ past_key_value = (key_states, value_states) if use_cache else None
71
+
72
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
73
+ if not output_attentions:
74
+ query_states = query_states.transpose(1, 2)
75
+ key_states = key_states.transpose(1, 2)
76
+ value_states = value_states.transpose(1, 2)
77
+
78
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
79
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
80
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
81
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
82
+ attn_output = xformers.ops.memory_efficient_attention(
83
+ query_states, key_states, value_states, attn_bias=None
84
+ )
85
+ else:
86
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
87
+ attn_output = xformers.ops.memory_efficient_attention(
88
+ query_states,
89
+ key_states,
90
+ value_states,
91
+ attn_bias=xformers.ops.LowerTriangularMask(),
92
+ )
93
+ attn_weights = None
94
+ else:
95
+ attn_weights = torch.matmul(
96
+ query_states, key_states.transpose(2, 3)
97
+ ) / math.sqrt(self.head_dim)
98
+
99
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
100
+ raise ValueError(
101
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
102
+ f" {attn_weights.size()}"
103
+ )
104
+
105
+ if attention_mask is not None:
106
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
107
+ raise ValueError(
108
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
109
+ )
110
+ attn_weights = attn_weights + attention_mask
111
+ attn_weights = torch.max(
112
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
113
+ )
114
+
115
+ # upcast attention to fp32
116
+ attn_weights = nn.functional.softmax(
117
+ attn_weights, dim=-1, dtype=torch.float32
118
+ ).to(query_states.dtype)
119
+ attn_output = torch.matmul(attn_weights, value_states)
120
+
121
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
122
+ raise ValueError(
123
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
124
+ f" {attn_output.size()}"
125
+ )
126
+
127
+ attn_output = attn_output.transpose(1, 2)
128
+
129
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
130
+ attn_output = self.o_proj(attn_output)
131
+ return attn_output, attn_weights, past_key_value
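For reference, the xformers fast path above expects `[batch, seq_len, num_heads, head_dim]` inputs and takes the causal structure from `LowerTriangularMask` rather than from an additive mask. A tiny shape sanity check (assumes a CUDA device and the xformers package):

```python
import torch
import xformers.ops as xops

q = k = v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.float16)
out = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
print(out.shape)  # torch.Size([2, 16, 8, 64])
```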
flash_vstream/train/train.py ADDED
@@ -0,0 +1,1069 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # ------------------------------------------------------------------------
3
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
5
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
6
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ import os
21
+ import copy
22
+ import json
23
+ import torch
24
+ import random
25
+ import logging
26
+ import pathlib
27
+ import transformers
28
+ from dataclasses import dataclass, field
29
+ from typing import Dict, Optional, Sequence, List
30
+
31
+
32
+ from flash_vstream.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
33
+ from torch.utils.data import Dataset
34
+ from flash_vstream.train.vstream_trainer import VStreamTrainer
35
+
36
+ from flash_vstream import conversation as conversation_lib
37
+ from flash_vstream.model import VStreamLlamaForCausalLM, VStreamConfig
38
+ from flash_vstream.mm_utils import tokenizer_image_token
39
+
40
+ from PIL import Image
41
+ from decord import VideoReader
42
+ from safetensors.torch import load_file, save_file
43
+
44
+
45
+ local_rank = None
46
+
47
+
48
+ def rank0_print(*args):
49
+ if local_rank == 0:
50
+ print(*args)
51
+
52
+
53
+ @dataclass
54
+ class ModelArguments:
55
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
56
+ version: Optional[str] = field(default="v0")
57
+ freeze_backbone: bool = field(default=False)
58
+ tune_mm_mlp_adapter: bool = field(default=False)
59
+ vision_tower: Optional[str] = field(default=None)
60
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
61
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
62
+ mm_projector_type: Optional[str] = field(default='linear')
63
+ mm_use_im_start_end: bool = field(default=False)
64
+ mm_use_im_patch_token: bool = field(default=True)
65
+ mm_vision_select_feature: Optional[str] = field(default="patch")
66
+ mm_use_4_vision_tokens: bool = field(default=False)
67
+ compress_type: Optional[str] = field(default=None)
68
+ compress_size: int = field(default=4)
69
+ compress_long_memory_size: int = field(default=1)
70
+ compress_Turing_memory_size: int = field(default=1)
71
+ compress_Turing_hidden_dim: int = field(default=32)
72
+ compress_Turing_update_ratio: float = field(default=0.2)
73
+
74
+
75
+ @dataclass
76
+ class DataArguments:
77
+ data_path: str = field(default=None,
78
+ metadata={"help": "Path to the training data."})
79
+ lazy_preprocess: bool = False
80
+ is_multimodal: bool = False
81
+ image_folder: Optional[str] = field(default=None)
82
+ video_folder: Optional[str] = field(default=None)
83
+ video_fps: Optional[int] = field(default=1)
84
+ video_token: Optional[int] = field(default=2)
85
+ video_max_frames: Optional[int] = field(default=50)
86
+ video_long_memory_length: Optional[int] = field(default=10)
87
+ video_Turing_memory_length: Optional[int] = field(default=10)
88
+ video_short_memory_length: Optional[int] = field(default=10)
89
+ video_current_memory_length: Optional[int] = field(default=1)
90
+ video_sample_type: Optional[str] = field(default='center') # center, uniform, drop, merge
91
+ image_aspect_ratio: str = 'square'
92
+
93
+
94
+ @dataclass
95
+ class TrainingArguments(transformers.TrainingArguments):
96
+ cache_dir: Optional[str] = field(default=None)
97
+ optim: str = field(default="adamw_torch")
98
+ remove_unused_columns: bool = field(default=False)
99
+ freeze_mm_mlp_adapter: bool = field(default=False)
100
+ model_max_length: int = field(
101
+ default=512,
102
+ metadata={
103
+ "help":
104
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
105
+ },
106
+ )
107
+ double_quant: bool = field(
108
+ default=True,
109
+ metadata={"help": "Compress the quantization statistics through double quantization."}
110
+ )
111
+ quant_type: str = field(
112
+ default="nf4",
113
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
114
+ )
115
+ bits: int = field(
116
+ default=16,
117
+ metadata={"help": "How many bits to use."}
118
+ )
119
+ lora_enable: bool = False
120
+ lora_r: int = 64
121
+ lora_alpha: int = 16
122
+ lora_dropout: float = 0.05
123
+ lora_weight_path: str = ""
124
+ lora_bias: str = "none"
125
+ mm_projector_lr: Optional[float] = None
126
+ group_by_modality_length: bool = field(default=False)
127
+
128
+
129
+ def maybe_zero_3(param, ignore_status=False, name=None):
130
+ from deepspeed import zero
131
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
132
+ if hasattr(param, "ds_id"):
133
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
134
+ if not ignore_status:
135
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
136
+ with zero.GatheredParameters([param]):
137
+ param = param.data.detach().cpu().clone()
138
+ else:
139
+ param = param.detach().cpu().clone()
140
+ return param
141
+
142
+
143
+ # Borrowed from peft.utils.get_peft_model_state_dict
144
+ def get_peft_state_maybe_zero_3(named_params, bias):
145
+ if bias == "none":
146
+ to_return = {k: t for k, t in named_params if "lora_" in k}
147
+ elif bias == "all":
148
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
149
+ elif bias == "lora_only":
150
+ to_return = {}
151
+ maybe_lora_bias = {}
152
+ lora_bias_names = set()
153
+ for k, t in named_params:
154
+ if "lora_" in k:
155
+ to_return[k] = t
156
+ bias_name = k.split("lora_")[0] + "bias"
157
+ lora_bias_names.add(bias_name)
158
+ elif "bias" in k:
159
+ maybe_lora_bias[k] = t
160
+ for k, t in maybe_lora_bias.items():
161
+ if k in lora_bias_names:
162
+ to_return[k] = t
163
+ else:
164
+ raise NotImplementedError
165
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
166
+ return to_return
167
+
168
+
169
+ def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
170
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
171
+ if require_grad_only:
172
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
173
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
174
+ return to_return
175
+
176
+
177
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
178
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
179
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
180
+ return to_return
181
+
182
+
183
+ def find_all_linear_names(model):
184
+ cls = torch.nn.Linear
185
+ lora_module_names = set()
186
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
187
+ for name, module in model.named_modules():
188
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
189
+ continue
190
+ if isinstance(module, cls):
191
+ names = name.split('.')
192
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
193
+
194
+ if 'lm_head' in lora_module_names: # needed for 16-bit
195
+ lora_module_names.remove('lm_head')
196
+ return list(lora_module_names)
197
+
198
+
199
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
200
+ output_dir: str):
201
+ """Collects the state dict and dump to disk."""
202
+
203
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
204
+ # Only save Adapter
205
+ keys_to_match = ['mm_projector']
206
+ if getattr(trainer.args, "use_im_start_end", False):
207
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
208
+
209
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
210
+ trainer.model.config.save_pretrained(output_dir)
211
+
212
+ current_folder = output_dir.split('/')[-1]
213
+ parent_folder = os.path.dirname(output_dir)
214
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
215
+ if current_folder.startswith('checkpoint-'):
216
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
217
+ os.makedirs(mm_projector_folder, exist_ok=True)
218
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
219
+ else:
220
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
221
+ return
222
+
223
+ if trainer.deepspeed:
224
+ torch.cuda.synchronize()
225
+ trainer.save_model(output_dir)
226
+ return
227
+
228
+ state_dict = trainer.model.state_dict()
229
+ if trainer.args.should_save:
230
+ cpu_state_dict = {
231
+ key: value.cpu()
232
+ for key, value in state_dict.items()
233
+ }
234
+ del state_dict
235
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
236
+
237
+
238
+ def smart_tokenizer_and_embedding_resize(
239
+ special_tokens_dict: Dict,
240
+ tokenizer: transformers.PreTrainedTokenizer,
241
+ model: transformers.PreTrainedModel,
242
+ ):
243
+ """Resize tokenizer and embedding.
244
+
245
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
246
+ """
247
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
248
+ model.resize_token_embeddings(len(tokenizer))
249
+
250
+ if num_new_tokens > 0:
251
+ input_embeddings = model.get_input_embeddings().weight.data
252
+ output_embeddings = model.get_output_embeddings().weight.data
253
+
254
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
255
+ dim=0, keepdim=True)
256
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
257
+ dim=0, keepdim=True)
258
+
259
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
260
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
261
+
262
+
263
+ def _tokenize_fn(strings: Sequence[str],
264
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
265
+ """Tokenize a list of strings."""
266
+ tokenized_list = [
267
+ tokenizer(
268
+ text,
269
+ return_tensors="pt",
270
+ padding="longest",
271
+ max_length=tokenizer.model_max_length,
272
+ truncation=True,
273
+ ) for text in strings
274
+ ]
275
+ input_ids = labels = [
276
+ tokenized.input_ids[0] for tokenized in tokenized_list
277
+ ]
278
+ input_ids_lens = labels_lens = [
279
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
280
+ for tokenized in tokenized_list
281
+ ]
282
+ return dict(
283
+ input_ids=input_ids,
284
+ labels=labels,
285
+ input_ids_lens=input_ids_lens,
286
+ labels_lens=labels_lens,
287
+ )
288
+
289
+
290
+ def _mask_targets(target, tokenized_lens, speakers):
291
+ # cur_idx = 0
292
+ cur_idx = tokenized_lens[0]
293
+ tokenized_lens = tokenized_lens[1:]
294
+ target[:cur_idx] = IGNORE_INDEX
295
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
296
+ if speaker == "human":
297
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
298
+ cur_idx += tokenized_len
299
+
300
+
301
+ def _add_speaker_and_signal(header, source, get_conversation=True):
302
+ """Add speaker and start/end signal on each round."""
303
+ BEGIN_SIGNAL = "### "
304
+ END_SIGNAL = "\n"
305
+ conversation = header
306
+ for sentence in source:
307
+ from_str = sentence["from"]
308
+ if from_str.lower() == "human":
309
+ from_str = conversation_lib.default_conversation.roles[0]
310
+ elif from_str.lower() == "gpt":
311
+ from_str = conversation_lib.default_conversation.roles[1]
312
+ else:
313
+ from_str = 'unknown'
314
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
315
+ sentence["value"] + END_SIGNAL)
316
+ if get_conversation:
317
+ conversation += sentence["value"]
318
+ conversation += BEGIN_SIGNAL
319
+ return conversation
320
+
321
+
322
+ def preprocess_multimodal(
323
+ sources: Sequence[str],
324
+ data_args: DataArguments
325
+ ) -> Dict:
326
+ is_multimodal = data_args.is_multimodal
327
+ if not is_multimodal:
328
+ return sources
329
+
330
+ for source in sources:
331
+ for sentence in source:
332
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
333
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
334
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
335
+ sentence['value'] = sentence['value'].strip()
336
+ if "mmtag" in conversation_lib.default_conversation.version:
337
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
338
+ replace_token = DEFAULT_IMAGE_TOKEN
339
+ if data_args.mm_use_im_start_end:
340
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
341
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
342
+
343
+ return sources
344
+
345
+
346
+ def preprocess_llama_2(
347
+ sources,
348
+ tokenizer: transformers.PreTrainedTokenizer,
349
+ has_image: bool = False
350
+ ) -> Dict:
351
+ conv = conversation_lib.default_conversation.copy()
352
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
353
+
354
+ # Apply prompt templates
355
+ conversations = []
356
+ for i, source in enumerate(sources):
357
+ if roles[source[0]["from"]] != conv.roles[0]:
358
+ # Skip the first one if it is not from human
359
+ source = source[1:]
360
+
361
+ conv.messages = []
362
+ for j, sentence in enumerate(source):
363
+ role = roles[sentence["from"]]
364
+ assert role == conv.roles[j % 2], f"{i}"
365
+ conv.append_message(role, sentence["value"])
366
+ conversations.append(conv.get_prompt())
367
+
368
+ # Tokenize conversations
369
+
370
+ if has_image:
371
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
372
+ else:
373
+ input_ids = tokenizer(
374
+ conversations,
375
+ return_tensors="pt",
376
+ padding="longest",
377
+ max_length=tokenizer.model_max_length,
378
+ truncation=True,
379
+ ).input_ids
380
+
381
+ targets = input_ids.clone()
382
+
383
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
384
+
385
+ # Mask targets
386
+ sep = "[/INST] "
387
+ for conversation, target in zip(conversations, targets):
388
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
389
+
390
+ rounds = conversation.split(conv.sep2)
391
+ cur_len = 1
392
+ target[:cur_len] = IGNORE_INDEX
393
+ for i, rou in enumerate(rounds):
394
+ if rou == "":
395
+ break
396
+
397
+ parts = rou.split(sep)
398
+ if len(parts) != 2:
399
+ break
400
+ parts[0] += sep
401
+
402
+ if has_image:
403
+ round_len = len(tokenizer_image_token(rou, tokenizer))
404
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
405
+ else:
406
+ round_len = len(tokenizer(rou).input_ids)
407
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
408
+
409
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
410
+
411
+ cur_len += round_len
412
+ target[cur_len:] = IGNORE_INDEX
413
+
414
+ if cur_len < tokenizer.model_max_length:
415
+ if cur_len != total_len:
416
+ target[:] = IGNORE_INDEX
417
+ print(
418
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
419
+ f" (ignored)"
420
+ )
421
+
422
+ return dict(
423
+ input_ids=input_ids,
424
+ labels=targets,
425
+ )
426
+
427
+
428
+ def preprocess_v1(
429
+ sources,
430
+ tokenizer: transformers.PreTrainedTokenizer,
431
+ has_image: bool = False
432
+ ) -> Dict:
433
+ conv = conversation_lib.default_conversation.copy()
434
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
435
+
436
+ # Apply prompt templates
437
+ conversations = []
438
+ for i, source in enumerate(sources):
439
+ if roles[source[0]["from"]] != conv.roles[0]:
440
+ # Skip the first one if it is not from human
441
+ source = source[1:]
442
+
443
+ conv.messages = []
444
+ for j, sentence in enumerate(source):
445
+ role = roles[sentence["from"]]
446
+ assert role == conv.roles[j % 2], f"{i}"
447
+ conv.append_message(role, sentence["value"])
448
+ conversations.append(conv.get_prompt())
449
+
450
+ # Tokenize conversations
451
+
452
+ if has_image:
453
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
454
+ else:
455
+ input_ids = tokenizer(
456
+ conversations,
457
+ return_tensors="pt",
458
+ padding="longest",
459
+ max_length=tokenizer.model_max_length,
460
+ truncation=True,
461
+ ).input_ids
462
+
463
+ targets = input_ids.clone()
464
+
465
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
466
+
467
+ # Mask targets
468
+ sep = conv.sep + conv.roles[1] + ": "
469
+ for conversation, target in zip(conversations, targets):
470
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
471
+
472
+ rounds = conversation.split(conv.sep2)
473
+ cur_len = 1
474
+ target[:cur_len] = IGNORE_INDEX
475
+ for i, rou in enumerate(rounds):
476
+ if rou == "":
477
+ break
478
+
479
+ parts = rou.split(sep)
480
+ if len(parts) != 2:
481
+ break
482
+ parts[0] += sep
483
+
484
+ if has_image:
485
+ round_len = len(tokenizer_image_token(rou, tokenizer))
486
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
487
+ else:
488
+ round_len = len(tokenizer(rou).input_ids)
489
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
490
+
491
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
492
+
493
+ cur_len += round_len
494
+ target[cur_len:] = IGNORE_INDEX
495
+
496
+ if cur_len < tokenizer.model_max_length:
497
+ if cur_len != total_len:
498
+ target[:] = IGNORE_INDEX
499
+ print(
500
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
501
+ f" (ignored)"
502
+ )
503
+
504
+ return dict(
505
+ input_ids=input_ids,
506
+ labels=targets,
507
+ )
508
+
509
+
510
+ def preprocess_mpt(
511
+ sources,
512
+ tokenizer: transformers.PreTrainedTokenizer,
513
+ ) -> Dict:
514
+ conv = conversation_lib.default_conversation.copy()
515
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
516
+
517
+ # Apply prompt templates
518
+ conversations = []
519
+ for i, source in enumerate(sources):
520
+ if roles[source[0]["from"]] != conv.roles[0]:
521
+ # Skip the first one if it is not from human
522
+ source = source[1:]
523
+
524
+ conv.messages = []
525
+ for j, sentence in enumerate(source):
526
+ role = roles[sentence["from"]]
527
+ assert role == conv.roles[j % 2], f"{i}"
528
+ conv.append_message(role, sentence["value"])
529
+ conversations.append(conv.get_prompt())
530
+
531
+ # Tokenize conversations
532
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
533
+ targets = input_ids.clone()
534
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
535
+
536
+ # Mask targets
537
+ sep = conv.sep + conv.roles[1]
538
+ for conversation, target in zip(conversations, targets):
539
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
540
+
541
+ rounds = conversation.split(conv.sep)
542
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
543
+ for conv_idx in range(3, len(rounds), 2):
544
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
545
+ cur_len = 0
546
+ target[:cur_len] = IGNORE_INDEX
547
+ for i, rou in enumerate(re_rounds):
548
+ if rou == "":
549
+ break
550
+
551
+ parts = rou.split(sep)
552
+ if len(parts) != 2:
553
+ break
554
+ parts[0] += sep
555
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
556
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
557
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
558
+
559
+ cur_len += round_len
560
+ target[cur_len:] = IGNORE_INDEX
561
+
562
+ if cur_len < tokenizer.model_max_length:
563
+ if cur_len != total_len:
564
+ target[:] = IGNORE_INDEX
565
+ print(
566
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
567
+ f" (ignored)"
568
+ )
569
+
570
+ return dict(
571
+ input_ids=input_ids,
572
+ labels=targets,
573
+ )
574
+
575
+
576
+ def preprocess_plain(
577
+ sources: Sequence[str],
578
+ tokenizer: transformers.PreTrainedTokenizer,
579
+ ) -> Dict:
580
+ # add end signal and concatenate together
581
+ conversations = []
582
+ for source in sources:
583
+ assert len(source) == 2
584
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
585
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
586
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
587
+ conversations.append(conversation)
588
+ # tokenize conversations
589
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
590
+ targets = copy.deepcopy(input_ids)
591
+ for target, source in zip(targets, sources):
592
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
593
+ target[:tokenized_len] = IGNORE_INDEX
594
+
595
+ return dict(input_ids=input_ids, labels=targets)
596
+
597
+
598
+ def preprocess(
599
+ sources: Sequence[str],
600
+ tokenizer: transformers.PreTrainedTokenizer,
601
+ has_image: bool = False
602
+ ) -> Dict:
603
+ """
604
+ Given a list of sources, each is a conversation list. This transform:
605
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
606
+ 2. Concatenate conversations together;
607
+ 3. Tokenize the concatenated conversation;
608
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
609
+ """
610
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
611
+ return preprocess_plain(sources, tokenizer)
612
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
613
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
614
+ if conversation_lib.default_conversation.version.startswith("v1"):
615
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
616
+ if conversation_lib.default_conversation.version == "mpt":
617
+ return preprocess_mpt(sources, tokenizer)
618
+ # add end signal and concatenate together
619
+ conversations = []
620
+ for source in sources:
621
+ header = f"{conversation_lib.default_conversation.system}\n\n"
622
+ conversation = _add_speaker_and_signal(header, source)
623
+ conversations.append(conversation)
624
+ # tokenize conversations
625
+ def get_tokenize_len(prompts):
626
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
627
+
628
+ if has_image:
629
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
630
+ else:
631
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
632
+ input_ids = conversations_tokenized["input_ids"]
633
+ targets = copy.deepcopy(input_ids)
634
+ for target, source in zip(targets, sources):
635
+ if has_image:
636
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
637
+ else:
638
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
639
+ speakers = [sentence["from"] for sentence in source]
640
+ _mask_targets(target, tokenized_lens, speakers)
641
+
642
+ return dict(input_ids=input_ids, labels=targets)
643
+
644
+
645
+ class LazySupervisedDataset(Dataset):
646
+ """Dataset for supervised fine-tuning."""
647
+
648
+ def __init__(self, data_path: str,
649
+ tokenizer: transformers.PreTrainedTokenizer,
650
+ data_args: DataArguments):
651
+ super(LazySupervisedDataset, self).__init__()
652
+ list_data_dict = json.load(open(data_path, "r"))
653
+
654
+ rank0_print("Formatting inputs...Skip in lazy mode")
655
+ self.tokenizer = tokenizer
656
+ self.list_data_dict = list_data_dict
657
+ self.data_args = data_args
658
+
659
+ def __len__(self):
660
+ return len(self.list_data_dict)
661
+
662
+ @property
663
+ def lengths(self):
664
+ length_list = []
665
+ for sample in self.list_data_dict:
666
+ img_tokens = 128 if 'image' in sample else 0
667
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
668
+ return length_list
669
+
670
+ @property
671
+ def modality_lengths(self):
672
+ length_list = []
673
+ for sample in self.list_data_dict:
674
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
675
+ cur_len = cur_len if ('image' in sample) or ('video' in sample) else -cur_len
676
+ length_list.append(cur_len)
677
+ return length_list
678
+
679
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
680
+ attempt, max_attempt = 0, 10
681
+ while attempt < max_attempt:
682
+ try:
683
+ sources = self.list_data_dict[i]
684
+ if isinstance(i, int):
685
+ sources = [sources]
686
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
687
+ feature = None
688
+ if 'image' in sources[0]:
689
+ image_file = self.list_data_dict[i]['image']
690
+ image_folder = self.data_args.image_folder
691
+ image_file = os.path.join(image_folder, image_file)
692
+ suffix = image_file.split('.')[-1]
693
+
694
+ if 'features' in image_folder:
695
+ # TODO: load video feature, not supported yet
696
+ image_file = image_file.replace(suffix, 'safetensors')
697
+ if not os.path.exists(image_file):
698
+ print('Image file {} does not exist!'.format(image_file))
699
+ feature = load_file(image_file)['feature'].unsqueeze(0)
700
+ sources = preprocess_multimodal(
701
+ copy.deepcopy([e["conversations"] for e in sources]),
702
+ self.data_args)
703
+
704
+ else:
705
+ processor = self.data_args.image_processor
706
+ image = Image.open(image_file).convert('RGB')
707
+ if self.data_args.image_aspect_ratio == 'pad':
708
+ def expand2square(pil_img, background_color):
709
+ width, height = pil_img.size
710
+ if width == height:
711
+ return pil_img
712
+ elif width > height:
713
+ result = Image.new(pil_img.mode, (width, width), background_color)
714
+ result.paste(pil_img, (0, (width - height) // 2))
715
+ return result
716
+ else:
717
+ result = Image.new(pil_img.mode, (height, height), background_color)
718
+ result.paste(pil_img, ((height - width) // 2, 0))
719
+ return result
720
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
721
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
722
+ else:
723
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
724
+ sources = preprocess_multimodal(
725
+ copy.deepcopy([e["conversations"] for e in sources]),
726
+ self.data_args)
727
+
728
+ elif 'video' in sources[0]:
729
+ video_file = self.list_data_dict[i]['video']
730
+ video_folder = self.data_args.video_folder
731
+ video_file = os.path.join(video_folder, video_file)
732
+ suffix = video_file.split('.')[-1]
733
+
734
+ if 'features' in video_folder:
735
+ # TODO: load video feature, not supported yet
736
+ video_file = video_file.replace(suffix, 'safetensors')
737
+ if not os.path.exists(video_file):
738
+ print('Video file {} does not exist!'.format(video_file))
739
+ feature = load_file(video_file)['feature']
740
+ if 'time' in self.list_data_dict[i]: # breakpoint mode
741
+ if 'time_9dense' in self.list_data_dict[i]:
742
+ tim = self.list_data_dict[i]['time_9dense'] // 4
743
+ start = max(tim - 6 * 9, 0)
744
+ end = min(tim + 6 * 9, feature.shape[0])
745
+ feature = feature[start:end]
746
+ else:
747
+ expansion = 15
748
+ if 'time_9' in self.list_data_dict[i]:
749
+ expansion = 9
750
+ tim = self.list_data_dict[i]['time']
751
+ start = max(tim - expansion, 0)
752
+ end = min(tim + expansion, feature.shape[0])
753
+ feature = feature[start:end]
754
+ elif 'time_9dense' in self.list_data_dict[i]:
755
+ feature = feature[::6]
756
+
757
+ sources = preprocess_multimodal(
758
+ copy.deepcopy([e["conversations"] for e in sources]),
759
+ self.data_args)
760
+ else:
761
+ # directly load video file
762
+ if not os.path.exists(video_file):
763
+ print('File {} does not exist!'.format(video_file))
764
+ vr = VideoReader(video_file, num_threads=4)
765
+ sample_fps = round(vr.get_avg_fps()/self.data_args.video_fps)
766
+ frame_idx = [i for i in range(0, len(vr), sample_fps)]
767
+ if len(frame_idx) > self.data_args.video_max_frames:
768
+ if self.data_args.video_sample_type == 'center':
769
+ # select middle frames
770
+ start_pos = (len(frame_idx) - self.data_args.video_max_frames) // 2
771
+ frame_idx = frame_idx[start_pos:start_pos + self.data_args.video_max_frames]
772
+ elif self.data_args.video_sample_type == 'uniform':
773
+ scale = 1.0 * len(frame_idx) / self.data_args.video_max_frames
774
+ uniform_idx = [round((i + 1) * scale - 1) for i in range(self.data_args.video_max_frames)]
775
+ frame_idx = [frame_idx[i] for i in uniform_idx]
776
+ elif len(frame_idx) > 18000:
777
+ scale = 1.0 * len(frame_idx) / 180
778
+ uniform_idx = [round((i + 1) * scale - 1) for i in range(180)]
779
+ frame_idx = [frame_idx[i] for i in uniform_idx]
780
+ video = vr.get_batch(frame_idx).asnumpy()
781
+ processor = self.data_args.image_processor
782
+ image = processor.preprocess(video, return_tensors='pt')['pixel_values']
783
+ sources = preprocess_multimodal(
784
+ copy.deepcopy([e["conversations"] for e in sources]),
785
+ self.data_args)
786
+
787
+ else:
788
+ sources = copy.deepcopy([e["conversations"] for e in sources])
789
+ break
790
+ except Exception as e:
791
+ attempt += 1
792
+ print(f"Error in loading id:{i} sample, retrying {attempt} time... Error={e}")
793
+ i = random.randint(0, len(self.list_data_dict)-1)
794
+
795
+ has_image = ('image' in self.list_data_dict[i]) or ('video' in self.list_data_dict[i])
796
+ data_dict = preprocess(
797
+ sources,
798
+ self.tokenizer,
799
+ has_image=has_image)
800
+ if isinstance(i, int):
801
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
802
+ labels=data_dict["labels"][0])
803
+
804
+ # image exist in the data
805
+ if 'image' in self.list_data_dict[i] or 'video' in self.list_data_dict[i]:
806
+ if feature is not None:
807
+ data_dict['feature'] = feature
808
+ else:
809
+ data_dict['image'] = image
810
+ elif self.data_args.is_multimodal:
811
+ # image does not exist in the data, but the model is multimodal
812
+ crop_size = self.data_args.image_processor.crop_size
813
+ patch_size = 14
814
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
815
+ data_dict['feature'] = torch.zeros((crop_size['height'] // patch_size) * (crop_size['width'] // patch_size), self.data_args.mm_hidden_size)
816
+ return data_dict
817
+
818
+
819
+ @dataclass
820
+ class DataCollatorForSupervisedDataset(object):
821
+ """Collate examples for supervised fine-tuning."""
822
+
823
+ tokenizer: transformers.PreTrainedTokenizer
824
+
825
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
826
+ input_ids, labels = tuple([instance[key] for instance in instances]
827
+ for key in ("input_ids", "labels"))
828
+ input_ids = torch.nn.utils.rnn.pad_sequence(
829
+ input_ids,
830
+ batch_first=True,
831
+ padding_value=self.tokenizer.pad_token_id)
832
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
833
+ batch_first=True,
834
+ padding_value=IGNORE_INDEX)
835
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
836
+ labels = labels[:, :self.tokenizer.model_max_length]
837
+ batch = dict(
838
+ input_ids=input_ids,
839
+ labels=labels,
840
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
841
+ )
842
+
843
+ if 'feature' in instances[0]:
844
+ batch['features'] = [instance['feature'] for instance in instances]
845
+ elif 'image' in instances[0]:
846
+ images = [instance['image'] for instance in instances]
847
+ if all(x is not None and x.shape == images[0].shape for x in images):
848
+ batch['images'] = torch.stack(images)
849
+ else:
850
+ batch['images'] = images
851
+
852
+
853
+ return batch
854
+
855
+
856
+ def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
857
+ data_args) -> Dict:
858
+ """Make dataset and collator for supervised fine-tuning."""
859
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
860
+ data_path=data_args.data_path,
861
+ data_args=data_args)
862
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
863
+ return dict(train_dataset=train_dataset,
864
+ eval_dataset=None,
865
+ data_collator=data_collator)
866
+
867
+
868
+ def train():
869
+ global local_rank
870
+
871
+ parser = transformers.HfArgumentParser(
872
+ (ModelArguments, DataArguments, TrainingArguments))
873
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
874
+ local_rank = training_args.local_rank
875
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
876
+
877
+ bnb_model_from_pretrained_args = {}
878
+ if training_args.bits in [4, 8]:
879
+ from transformers import BitsAndBytesConfig
880
+ bnb_model_from_pretrained_args.update(dict(
881
+ device_map={"": training_args.device},
882
+ load_in_4bit=training_args.bits == 4,
883
+ load_in_8bit=training_args.bits == 8,
884
+ quantization_config=BitsAndBytesConfig(
885
+ load_in_4bit=training_args.bits == 4,
886
+ load_in_8bit=training_args.bits == 8,
887
+ llm_int8_skip_modules=["mm_projector"],
888
+ llm_int8_threshold=6.0,
889
+ llm_int8_has_fp16_weight=False,
890
+ bnb_4bit_compute_dtype=compute_dtype,
891
+ bnb_4bit_use_double_quant=training_args.double_quant,
892
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
893
+ )
894
+ ))
895
+
896
+ if model_args.vision_tower is not None:
897
+ model = VStreamLlamaForCausalLM.from_pretrained(
898
+ model_args.model_name_or_path,
899
+ cache_dir=training_args.cache_dir,
900
+ **bnb_model_from_pretrained_args
901
+ )
902
+ else:
903
+ model = transformers.LlamaForCausalLM.from_pretrained(
904
+ model_args.model_name_or_path,
905
+ cache_dir=training_args.cache_dir,
906
+ **bnb_model_from_pretrained_args
907
+ )
908
+ model.config.use_cache = False
909
+
910
+ if model_args.freeze_backbone:
911
+ model.model.requires_grad_(False)
912
+
913
+ if training_args.bits in [4, 8]:
914
+ from peft import prepare_model_for_kbit_training
915
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
916
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
917
+
918
+ if training_args.gradient_checkpointing:
919
+ if hasattr(model, "enable_input_require_grads"):
920
+ model.enable_input_require_grads()
921
+ else:
922
+ def make_inputs_require_grad(module, input, output):
923
+ output.requires_grad_(True)
924
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
925
+
926
+ if training_args.lora_enable:
927
+ from peft import LoraConfig, get_peft_model
928
+ lora_config = LoraConfig(
929
+ r=training_args.lora_r,
930
+ lora_alpha=training_args.lora_alpha,
931
+ target_modules=find_all_linear_names(model),
932
+ lora_dropout=training_args.lora_dropout,
933
+ bias=training_args.lora_bias,
934
+ task_type="CAUSAL_LM",
935
+ )
936
+ if training_args.bits == 16:
937
+ if training_args.bf16:
938
+ model.to(torch.bfloat16)
939
+ if training_args.fp16:
940
+ model.to(torch.float16)
941
+ rank0_print("Adding LoRA adapters...")
942
+ model = get_peft_model(model, lora_config)
943
+
944
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
945
+ model_args.model_name_or_path,
946
+ cache_dir=training_args.cache_dir,
947
+ model_max_length=training_args.model_max_length,
948
+ padding_side="right",
949
+ use_fast=False,
950
+ )
951
+
952
+ if model_args.version == "v0":
953
+ if tokenizer.pad_token is None:
954
+ smart_tokenizer_and_embedding_resize(
955
+ special_tokens_dict=dict(pad_token="[PAD]"),
956
+ tokenizer=tokenizer,
957
+ model=model,
958
+ )
959
+ elif model_args.version == "v0.5":
960
+ tokenizer.pad_token = tokenizer.unk_token
961
+ else:
962
+ tokenizer.pad_token = tokenizer.unk_token
963
+ if model_args.version in conversation_lib.conv_templates:
964
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
965
+ else:
966
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
967
+
968
+ if model_args.vision_tower is not None:
969
+ model_args.video_sample_type = data_args.video_sample_type
970
+ model_args.video_max_frames = data_args.video_max_frames
971
+ model_args.video_long_memory_length = data_args.video_long_memory_length
972
+ model_args.video_Turing_memory_length = data_args.video_Turing_memory_length
973
+ model_args.video_short_memory_length = data_args.video_short_memory_length
974
+ model_args.video_current_memory_length = data_args.video_current_memory_length
975
+ model.get_model().initialize_vision_modules(
976
+ model_args=model_args,
977
+ fsdp=training_args.fsdp
978
+ )
979
+
980
+ vision_tower = model.get_vision_tower()
981
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
982
+
983
+ data_args.image_processor = vision_tower.image_processor
984
+ data_args.is_multimodal = True
985
+
986
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
987
+ model.config.tokenizer_padding_side = tokenizer.padding_side
988
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
989
+
990
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
991
+ if model_args.tune_mm_mlp_adapter:
992
+ model.requires_grad_(False)
993
+ for p in model.get_model().mm_projector.parameters():
994
+ p.requires_grad = True
995
+ for p in model.get_model().attention_model.parameters():
996
+ p.requires_grad = True
997
+
998
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
999
+ if training_args.freeze_mm_mlp_adapter:
1000
+ for p in model.get_model().mm_projector.parameters():
1001
+ p.requires_grad = False
1002
+ for p in model.get_model().attention_model.parameters():
1003
+ p.requires_grad = False
1004
+
1005
+ if training_args.bits in [4, 8]:
1006
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
1007
+
1008
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
1009
+ model.config.mm_projector_lr = training_args.mm_projector_lr
1010
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
1011
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
1012
+ model.config.mm_use_4_vision_tokens = model_args.mm_use_4_vision_tokens
1013
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
1014
+
1015
+ if training_args.bits in [4, 8]:
1016
+ from peft.tuners.lora import LoraLayer
1017
+ for name, module in model.named_modules():
1018
+ if isinstance(module, LoraLayer):
1019
+ if training_args.bf16:
1020
+ module = module.to(torch.bfloat16)
1021
+ if 'norm' in name:
1022
+ module = module.to(torch.float32)
1023
+ if 'lm_head' in name or 'embed_tokens' in name:
1024
+ if hasattr(module, 'weight'):
1025
+ if training_args.bf16 and module.weight.dtype == torch.float32:
1026
+ module = module.to(torch.bfloat16)
1027
+
1028
+ data_args.mm_hidden_size = model.get_vision_tower().hidden_size
1029
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
1030
+ data_args=data_args)
1031
+ trainer = VStreamTrainer(model=model,
1032
+ tokenizer=tokenizer,
1033
+ args=training_args,
1034
+ **data_module)
1035
+
1036
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
1037
+ trainer.train(resume_from_checkpoint=True)
1038
+ else:
1039
+ trainer.train()
1040
+ trainer.save_state()
1041
+
1042
+ model.config.use_cache = True
1043
+
1044
+ if training_args.lora_enable:
1045
+ state_dict = get_peft_state_maybe_zero_3(
1046
+ model.named_parameters(), training_args.lora_bias
1047
+ )
1048
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
1049
+ model.named_parameters()
1050
+ )
1051
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
1052
+ model.config.save_pretrained(training_args.output_dir)
1053
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
1054
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
1055
+ else:
1056
+ safe_save_model_for_hf_trainer(trainer=trainer,
1057
+ output_dir=training_args.output_dir)
1058
+
1059
+
1060
+ if __name__ == "__main__":
1061
+ # random.seed(42)
1062
+ # np.random.seed(42)
1063
+ # torch.manual_seed(42)
1064
+ # torch.cuda.manual_seed(42)
1065
+ # torch.cuda.manual_seed_all(42)
1066
+ # torch.backends.cudnn.deterministic = True
1067
+ # torch.backends.cudnn.benchmark = False
1068
+
1069
+ train()
flash_vstream/train/train_mem.py ADDED
@@ -0,0 +1,14 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA.
2
+ # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
3
+ # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
4
+ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
5
+
6
+ # Need to call this before importing transformers.
7
+ from flash_vstream.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
8
+
9
+ replace_llama_attn_with_flash_attn()
10
+
11
+ from flash_vstream.train.train import train
12
+
13
+ if __name__ == "__main__":
14
+ train()
flash_vstream/train/train_xformers.py ADDED
@@ -0,0 +1,15 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
4
+
5
+ # Need to call this before importing transformers.
6
+ from flash_vstream.train.llama_xformers_attn_monkey_patch import (
7
+ replace_llama_attn_with_xformers_attn,
8
+ )
9
+
10
+ replace_llama_attn_with_xformers_attn()
11
+
12
+ from flash_vstream.train.train import train
13
+
14
+ if __name__ == "__main__":
15
+ train()
flash_vstream/train/vstream_trainer.py ADDED
@@ -0,0 +1,248 @@
1
+ # This file may have been modified by Flash-VStream Authors ("Flash-VStream Modifications"). All Flash-VStream Modifications are Copyright 2024 Flash-VStream Authors.
2
+ # ------------------------------------------------------------------------
3
+ # Based on https://github.com/haotian-liu/LLaVA. Below is the original copyright:
4
+ # Copyright 2023 Haotian Liu
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import os
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from torch.utils.data import Sampler
23
+
24
+ from transformers import Trainer
25
+ from transformers.trainer import (
26
+ is_sagemaker_mp_enabled,
27
+ get_parameter_names,
28
+ has_length,
29
+ ALL_LAYERNORM_LAYERS,
30
+ ShardedDDPOption,
31
+ logger,
32
+ )
33
+ from typing import List, Optional
34
+
35
+
36
+ def maybe_zero_3(param, ignore_status=False, name=None):
37
+ from deepspeed import zero
38
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
39
+ if hasattr(param, "ds_id"):
40
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
41
+ if not ignore_status:
42
+ print(name, 'no ignore status')
43
+ with zero.GatheredParameters([param]):
44
+ param = param.data.detach().cpu().clone()
45
+ else:
46
+ param = param.detach().cpu().clone()
47
+ return param
48
+
49
+
50
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
51
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
52
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
53
+ return to_return
54
+
55
+
56
+ def split_to_even_chunks(indices, lengths, num_chunks):
57
+ """
58
+ Split a list of indices into `num_chunks` chunks of roughly equal lengths.
59
+ """
60
+
61
+ if len(indices) % num_chunks != 0:
62
+ return [indices[i::num_chunks] for i in range(num_chunks)]
63
+
64
+ num_indices_per_chunk = len(indices) // num_chunks
65
+
66
+ chunks = [[] for _ in range(num_chunks)]
67
+ chunks_lengths = [0 for _ in range(num_chunks)]
68
+ for index in indices:
69
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
70
+ chunks[shortest_chunk].append(index)
71
+ chunks_lengths[shortest_chunk] += lengths[index]
72
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
73
+ chunks_lengths[shortest_chunk] = float("inf")
74
+
75
+ return chunks
76
+
77
+
78
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
79
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
80
+ assert all(l != 0 for l in lengths), "Should not have zero length."
81
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
82
+ # all samples are in the same modality
83
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
84
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
85
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
86
+
87
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
88
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
89
+ megabatch_size = world_size * batch_size
90
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
91
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
92
+
93
+ last_mm = mm_megabatches[-1]
94
+ last_lang = lang_megabatches[-1]
95
+ additional_batch = last_mm + last_lang
96
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
97
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
98
+ megabatches = [megabatches[i] for i in megabatch_indices]
99
+
100
+ if len(additional_batch) > 0:
101
+ megabatches.append(sorted(additional_batch))
102
+
103
+ return [i for megabatch in megabatches for i in megabatch]
104
+
105
+
106
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
107
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
108
+ indices = torch.randperm(len(lengths), generator=generator)
109
+ megabatch_size = world_size * batch_size
110
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
111
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
112
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
113
+
114
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
115
+
116
+
117
+ class LengthGroupedSampler(Sampler):
118
+ r"""
119
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
120
+ keeping a bit of randomness.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ batch_size: int,
126
+ world_size: int,
127
+ lengths: Optional[List[int]] = None,
128
+ generator=None,
129
+ group_by_modality: bool = False,
130
+ ):
131
+ if lengths is None:
132
+ raise ValueError("Lengths must be provided.")
133
+
134
+ self.batch_size = batch_size
135
+ self.world_size = world_size
136
+ self.lengths = lengths
137
+ self.generator = generator
138
+ self.group_by_modality = group_by_modality
139
+
140
+ def __len__(self):
141
+ return len(self.lengths)
142
+
143
+ def __iter__(self):
144
+ if self.group_by_modality:
145
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
146
+ else:
147
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
148
+ return iter(indices)
149
+
150
+
151
+ class VStreamTrainer(Trainer):
152
+
153
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
154
+ if self.train_dataset is None or not has_length(self.train_dataset):
155
+ return None
156
+
157
+ if self.args.group_by_modality_length:
158
+ lengths = self.train_dataset.modality_lengths
159
+ return LengthGroupedSampler(
160
+ self.args.train_batch_size,
161
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
162
+ lengths=lengths,
163
+ group_by_modality=True,
164
+ )
165
+ else:
166
+ return super()._get_train_sampler()
167
+
168
+ def create_optimizer(self):
169
+ """
170
+ Setup the optimizer.
171
+
172
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
173
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
174
+ """
175
+ if is_sagemaker_mp_enabled():
176
+ return super().create_optimizer()
177
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
178
+ return super().create_optimizer()
179
+
180
+ opt_model = self.model
181
+
182
+ if self.optimizer is None:
183
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
184
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
185
+ if self.args.mm_projector_lr is not None:
186
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
187
+ optimizer_grouped_parameters = [
188
+ {
189
+ "params": [
190
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
191
+ ],
192
+ "weight_decay": self.args.weight_decay,
193
+ },
194
+ {
195
+ "params": [
196
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
197
+ ],
198
+ "weight_decay": 0.0,
199
+ },
200
+ {
201
+ "params": [
202
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
203
+ ],
204
+ "weight_decay": self.args.weight_decay,
205
+ "lr": self.args.mm_projector_lr,
206
+ },
207
+ {
208
+ "params": [
209
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
210
+ ],
211
+ "weight_decay": 0.0,
212
+ "lr": self.args.mm_projector_lr,
213
+ },
214
+ ]
215
+ else:
216
+ optimizer_grouped_parameters = [
217
+ {
218
+ "params": [
219
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
220
+ ],
221
+ "weight_decay": self.args.weight_decay,
222
+ },
223
+ {
224
+ "params": [
225
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
226
+ ],
227
+ "weight_decay": 0.0,
228
+ },
229
+ ]
230
+
231
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
232
+
233
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
234
+ if optimizer_cls.__name__ == "Adam8bit":
235
+ import bitsandbytes
236
+
237
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
238
+
239
+ skipped = 0
240
+ for module in opt_model.modules():
241
+ if isinstance(module, nn.Embedding):
242
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
243
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
244
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
245
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
246
+ logger.info(f"skipped: {skipped/2**20}M params")
247
+
248
+ return self.optimizer
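A quick, hypothetical usage sketch of the LengthGroupedSampler defined above, assuming the package is importable as flash_vstream.train.vstream_trainer; the lengths are made-up toy values (in this codebase, positive lengths mark multimodal samples).
import torch
from flash_vstream.train.vstream_trainer import LengthGroupedSampler

# Toy lengths: samples of similar length are grouped into the same megabatch.
sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=1,
    lengths=[5, 100, 7, 90, 6, 80],
    generator=torch.Generator().manual_seed(0),
    group_by_modality=False,
)
print(list(iter(sampler)))  # a permutation of 0..5, sorted by length within each megabatch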
flash_vstream/utils.py ADDED
@@ -0,0 +1,128 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import datetime
4
+ import logging
5
+ import logging.handlers
6
+ import os
7
+ import sys
8
+
9
+ import requests
10
+
11
+ from flash_vstream.constants import LOGDIR
12
+
13
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
14
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15
+
16
+ handler = None
17
+
18
+
19
+ def build_logger(logger_name, logger_filename):
20
+ global handler
21
+
22
+ formatter = logging.Formatter(
23
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
24
+ datefmt="%Y-%m-%d %H:%M:%S",
25
+ )
26
+
27
+ # Set the format of root handlers
28
+ if not logging.getLogger().handlers:
29
+ logging.basicConfig(level=logging.INFO)
30
+ logging.getLogger().handlers[0].setFormatter(formatter)
31
+
32
+ # Redirect stdout and stderr to loggers
33
+ stdout_logger = logging.getLogger("stdout")
34
+ stdout_logger.setLevel(logging.INFO)
35
+ sl = StreamToLogger(stdout_logger, logging.INFO)
36
+ sys.stdout = sl
37
+
38
+ stderr_logger = logging.getLogger("stderr")
39
+ stderr_logger.setLevel(logging.ERROR)
40
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
41
+ sys.stderr = sl
42
+
43
+ # Get logger
44
+ logger = logging.getLogger(logger_name)
45
+ logger.setLevel(logging.INFO)
46
+
47
+ # Add a file handler for all loggers
48
+ if handler is None:
49
+ os.makedirs(LOGDIR, exist_ok=True)
50
+ filename = os.path.join(LOGDIR, logger_filename)
51
+ handler = logging.handlers.TimedRotatingFileHandler(
52
+ filename, when='D', utc=True, encoding='UTF-8')
53
+ handler.setFormatter(formatter)
54
+
55
+ for name, item in logging.root.manager.loggerDict.items():
56
+ if isinstance(item, logging.Logger):
57
+ item.addHandler(handler)
58
+
59
+ return logger
60
+
61
+
62
+ class StreamToLogger(object):
63
+ """
64
+ Fake file-like stream object that redirects writes to a logger instance.
65
+ """
66
+ def __init__(self, logger, log_level=logging.INFO):
67
+ self.terminal = sys.stdout
68
+ self.logger = logger
69
+ self.log_level = log_level
70
+ self.linebuf = ''
71
+
72
+ def __getattr__(self, attr):
73
+ return getattr(self.terminal, attr)
74
+
75
+ def write(self, buf):
76
+ temp_linebuf = self.linebuf + buf
77
+ self.linebuf = ''
78
+ for line in temp_linebuf.splitlines(True):
79
+ # From the io.TextIOWrapper docs:
80
+ # On output, if newline is None, any '\n' characters written
81
+ # are translated to the system default line separator.
82
+ # By default sys.stdout.write() expects '\n' newlines and then
83
+ # translates them so this is still cross platform.
84
+ if line[-1] == '\n':
85
+ self.logger.log(self.log_level, line.rstrip())
86
+ else:
87
+ self.linebuf += line
88
+
89
+ def flush(self):
90
+ if self.linebuf != '':
91
+ self.logger.log(self.log_level, self.linebuf.rstrip())
92
+ self.linebuf = ''
93
+
94
+
95
+ def disable_torch_init():
96
+ """
97
+ Disable the redundant torch default initialization to accelerate model creation.
98
+ """
99
+ import torch
100
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
101
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
102
+
103
+
104
+ def violates_moderation(text):
105
+ """
106
+ Check whether the text violates OpenAI moderation API.
107
+ """
108
+ url = "https://api.openai.com/v1/moderations"
109
+ headers = {"Content-Type": "application/json",
110
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
111
+ text = text.replace("\n", "")
112
+ data = "{" + '"input": ' + f'"{text}"' + "}"
113
+ data = data.encode("utf-8")
114
+ try:
115
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
116
+ flagged = ret.json()["results"][0]["flagged"]
117
+ except requests.exceptions.RequestException as e:
118
+ flagged = False
119
+ except KeyError as e:
120
+ flagged = False
121
+
122
+ return flagged
123
+
124
+
125
+ def pretty_print_semaphore(semaphore):
126
+ if semaphore is None:
127
+ return "None"
128
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
requirements.txt CHANGED
@@ -1 +1 @@
1
- huggingface_hub==0.22.2
1
+ huggingface_hub==0.22.2