Choiszt committed
Commit c62903f · 1 Parent(s): c87f81a

Update egogpt

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. egogpt/__pycache__/constants.cpython-310.pyc +0 -0
  2. egogpt/__pycache__/conversation.cpython-310.pyc +0 -0
  3. egogpt/__pycache__/mm_utils.cpython-310.pyc +0 -0
  4. egogpt/__pycache__/utils.cpython-310.pyc +0 -0
  5. egogpt/constants.py +11 -0
  6. egogpt/conversation.py +287 -0
  7. egogpt/mm_utils.py +450 -0
  8. egogpt/model/__init__.py +2 -0
  9. egogpt/model/__pycache__/__init__.cpython-310.pyc +0 -0
  10. egogpt/model/__pycache__/builder.cpython-310.pyc +0 -0
  11. egogpt/model/__pycache__/egogpt_arch.cpython-310.pyc +0 -0
  12. egogpt/model/builder.py +127 -0
  13. egogpt/model/egogpt_arch.py +1357 -0
  14. egogpt/model/language_model/__pycache__/egogpt_llama.cpython-310.pyc +0 -0
  15. egogpt/model/language_model/__pycache__/egogpt_qwen.cpython-310.pyc +0 -0
  16. egogpt/model/language_model/egogpt_llama.py +159 -0
  17. egogpt/model/language_model/egogpt_qwen.py +164 -0
  18. egogpt/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  19. egogpt/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc +0 -0
  20. egogpt/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc +0 -0
  21. egogpt/model/multimodal_encoder/builder.py +36 -0
  22. egogpt/model/multimodal_encoder/clip_encoder.py +235 -0
  23. egogpt/model/multimodal_encoder/siglip_encoder.py +742 -0
  24. egogpt/model/multimodal_projector/__pycache__/builder.cpython-310.pyc +0 -0
  25. egogpt/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc +0 -0
  26. egogpt/model/multimodal_projector/builder.py +68 -0
  27. egogpt/model/multimodal_projector/pooler_projector.py +34 -0
  28. egogpt/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc +0 -0
  29. egogpt/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc +0 -0
  30. egogpt/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc +0 -0
  31. egogpt/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc +0 -0
  32. egogpt/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc +0 -0
  33. egogpt/model/multimodal_resampler/builder.py +34 -0
  34. egogpt/model/multimodal_resampler/masked_drop.py +89 -0
  35. egogpt/model/multimodal_resampler/perceiver.py +172 -0
  36. egogpt/model/multimodal_resampler/qformer.py +1281 -0
  37. egogpt/model/multimodal_resampler/spatial_pool.py +57 -0
  38. egogpt/model/speech_encoder/__pycache__/audio.cpython-310.pyc +0 -0
  39. egogpt/model/speech_encoder/__pycache__/builder.cpython-310.pyc +0 -0
  40. egogpt/model/speech_encoder/__pycache__/decoding.cpython-310.pyc +0 -0
  41. egogpt/model/speech_encoder/__pycache__/model.cpython-310.pyc +0 -0
  42. egogpt/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc +0 -0
  43. egogpt/model/speech_encoder/__pycache__/timing.cpython-310.pyc +0 -0
  44. egogpt/model/speech_encoder/__pycache__/tokenizer.cpython-310.pyc +0 -0
  45. egogpt/model/speech_encoder/__pycache__/transcribe.cpython-310.pyc +0 -0
  46. egogpt/model/speech_encoder/__pycache__/utils.cpython-310.pyc +0 -0
  47. egogpt/model/speech_encoder/audio.py +157 -0
  48. egogpt/model/speech_encoder/builder.py +9 -0
  49. egogpt/model/speech_encoder/decoding.py +826 -0
  50. egogpt/model/speech_encoder/model.py +345 -0
egogpt/__pycache__/constants.cpython-310.pyc ADDED
Binary file (398 Bytes).
 
egogpt/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (6.24 kB).
 
egogpt/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (13.1 kB).
 
egogpt/__pycache__/utils.cpython-310.pyc ADDED
Binary file (12 kB).
 
egogpt/constants.py ADDED
@@ -0,0 +1,11 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ SPEECH_TOKEN_INDEX = -200
+ DEFAULT_SPEECH_TOKEN = "<speech>"
+ IMAGE_TOKEN_INDEX = -300
+ DEFAULT_IMAGE_TOKEN = "<image>"
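The negative indices above are sentinel values: they never collide with real vocabulary ids, so they can be spliced into input_ids to mark where projected image or speech features will be inserted, while IGNORE_INDEX is the label value excluded from the loss. A minimal sketch of that pattern follows; the helper below is illustrative only and not part of this commit (the repository does the equivalent work in mm_utils.py and egogpt_arch.py).

from egogpt.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX

def tokenize_with_placeholder(prompt, tokenizer,
                              placeholder=DEFAULT_IMAGE_TOKEN,
                              sentinel=IMAGE_TOKEN_INDEX):
    # Tokenize the text around each placeholder, then join the chunks with the
    # negative sentinel id; the model later swaps these positions for projected
    # multimodal features. (Simplified sketch: BOS handling is omitted.)
    chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(placeholder)]
    input_ids = list(chunks[0])
    for chunk in chunks[1:]:
        input_ids += [sentinel] + list(chunk)
    return input_ids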
egogpt/conversation.py ADDED
@@ -0,0 +1,287 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
17
+ import dataclasses
18
+ from enum import Enum, auto
19
+ from io import BytesIO
20
+ from typing import Any, List, Tuple, Union
21
+
22
+ from PIL import Image
23
+
24
+
25
+ class SeparatorStyle(Enum):
26
+ """Different separator style."""
27
+
28
+ TWO = auto()
29
+ PLAIN = auto()
30
+ CHATML = auto()
31
+ LLAMA_2 = auto()
32
+ LLAMA_3 = auto()
33
+ QWEN2 = auto()
34
+
35
+
36
+ @dataclasses.dataclass
37
+ class Conversation:
38
+ """A class that keeps all conversation history."""
39
+
40
+ system: str
41
+ roles: List[str]
42
+ messages: List[List[str]]
43
+ offset: int
44
+ sep_style: SeparatorStyle = SeparatorStyle.PLAIN
45
+ sep: str = "###"
46
+ sep2: str = None
47
+ version: str = "Unknown"
48
+
49
+ tokenizer_id: str = ""
50
+ tokenizer: Any = None
51
+ # Stop criteria (the default one is EOS token)
52
+ stop_str: Union[str, List[str]] = None
53
+ # Stops generation if meeting any token in this list
54
+ stop_token_ids: List[int] = None
55
+
56
+ skip_next: bool = False
57
+
58
+ def get_prompt(self):
59
+ messages = self.messages
60
+
61
+ if self.sep_style == SeparatorStyle.TWO:
62
+ seps = [self.sep, self.sep2]
63
+ ret = self.system + seps[0]
64
+ for i, (role, message) in enumerate(messages):
65
+ if message:
66
+ if type(message) is tuple:
67
+ message = message[0]
68
+ ret += role + ": " + message + seps[i % 2]
69
+ else:
70
+ ret += role + ":"
71
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
72
+ wrap_sys = (
73
+ lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>"
74
+ if len(msg) > 0
75
+ else msg
76
+ )
77
+ ret = "<|begin_of_text|>" + wrap_sys(self.system)
78
+ for i, (role, message) in enumerate(messages):
79
+ if message:
80
+ if type(message) is tuple:
81
+ message = message[0]
82
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
83
+ ret += message.strip() + self.sep2
84
+ else:
85
+ ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
86
+ return ret
87
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
88
+ wrap_sys = (
89
+ lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
90
+ )
91
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
92
+ ret = ""
93
+
94
+ for i, (role, message) in enumerate(messages):
95
+ if i == 0:
96
+ assert message, "first message should not be none"
97
+ assert role == self.roles[0], "first message should come from user"
98
+ if message:
99
+ if type(message) is tuple:
100
+ message, _, _ = message
101
+ if i == 0:
102
+ message = wrap_sys(self.system) + message
103
+ if i % 2 == 0:
104
+ message = wrap_inst(message)
105
+ ret += self.sep + message
106
+ else:
107
+ ret += " " + message + " " + self.sep2
108
+ else:
109
+ ret += ""
110
+ ret = ret.lstrip(self.sep)
111
+ elif self.sep_style == SeparatorStyle.PLAIN:
112
+ seps = [self.sep, self.sep2]
113
+ ret = self.system
114
+ for i, (role, message) in enumerate(messages):
115
+ if message:
116
+ if type(message) is tuple:
117
+ message, _, _ = message
118
+ ret += message + seps[i % 2]
119
+ else:
120
+ ret += ""
121
+
122
+ elif self.sep_style == SeparatorStyle.CHATML:
123
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
124
+ for role, message in messages:
125
+ if message:
126
+ if type(message) is tuple:
127
+ message, images = message
128
+ message = "<speech>" * len(images) + message
129
+ ret += role + "\n" + message + self.sep + "\n"
130
+ else:
131
+ ret += role + "\n"
132
+ return ret
133
+ elif self.sep_style == SeparatorStyle.QWEN2:
134
+ start = "<|im_start|>"
135
+ end = "<|im_end|>\n"
136
+ ret = start + "system\n" + self.system + end
137
+ for i, (role, message) in enumerate(messages):
138
+ if message:
139
+ if type(message) is tuple:
140
+ message, _, _ = message
141
+
142
+ if message.endswith("<|endoftext|>"):
143
+ message = message.replace("<|endoftext|>", "")
144
+ ret += start + role + "\n" + message + end + "<|endoftext|>"
145
+ else:
146
+ assert (
147
+ not "<|endoftext|>" in message
148
+ ), f"Invalid message: {message}"
149
+ ret += start + role + "\n" + message + end
150
+ else:
151
+ ret += start + role + "\n"
152
+ else:
153
+ raise ValueError(f"Invalid style: {self.sep_style}")
154
+
155
+ return ret
156
+
157
+ def append_message(self, role, message):
158
+ self.messages.append([role, message])
159
+
160
+ def to_gradio_chatbot(self):
161
+ ret = []
162
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
163
+ if i % 2 == 0:
164
+ if type(msg) is tuple:
165
+ msg, speech = msg
166
+ ret.append([msg, None])
167
+ else:
168
+ ret.append([msg, None])
169
+ else:
170
+ ret[-1][-1] = msg
171
+ return ret
172
+
173
+ def copy(self):
174
+ return Conversation(
175
+ system=self.system,
176
+ roles=self.roles,
177
+ messages=[[x, y] for x, y in self.messages],
178
+ offset=self.offset,
179
+ sep_style=self.sep_style,
180
+ sep=self.sep,
181
+ sep2=self.sep2,
182
+ version=self.version,
183
+ )
184
+
185
+ def dict(self):
186
+ if len(self.get_images()) > 0:
187
+ return {
188
+ "system": self.system,
189
+ "roles": self.roles,
190
+ "messages": [
191
+ [x, y[0] if type(y) is tuple else y] for x, y in self.messages
192
+ ],
193
+ "offset": self.offset,
194
+ "sep": self.sep,
195
+ "sep2": self.sep2,
196
+ }
197
+ return {
198
+ "system": self.system,
199
+ "roles": self.roles,
200
+ "messages": self.messages,
201
+ "offset": self.offset,
202
+ "sep": self.sep,
203
+ "sep2": self.sep2,
204
+ }
205
+
206
+
207
+ conv_vicuna_v1 = Conversation(
208
+ system="A chat between a curious user and an artificial intelligence assistant. "
209
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
210
+ roles=("USER", "ASSISTANT"),
211
+ version="v1",
212
+ messages=[],
213
+ offset=0,
214
+ sep_style=SeparatorStyle.TWO,
215
+ sep=" ",
216
+ sep2="</s>",
217
+ )
218
+
219
+ conv_llama_2 = Conversation(
220
+ system="You are a helpful language and speech assistant. "
221
+ "You are able to understand the speech content that the user provides, "
222
+ "and assist the user with a variety of tasks using natural language.",
223
+ roles=("USER", "ASSISTANT"),
224
+ version="llama_v2",
225
+ messages=[],
226
+ offset=0,
227
+ sep_style=SeparatorStyle.LLAMA_2,
228
+ sep="<s>",
229
+ sep2="</s>",
230
+ )
231
+
232
+ conv_llama_3 = Conversation(
233
+ system="You are a helpful language and speech assistant. "
234
+ "You are able to understand the speech content that the user provides, "
235
+ "and assist the user with a variety of tasks using natural language.",
236
+ roles=("user", "assistant"),
237
+ version="llama_v3",
238
+ messages=[],
239
+ offset=0,
240
+ sep_style=SeparatorStyle.LLAMA_3,
241
+ sep="",
242
+ sep2="<|eot_id|>",
243
+ )
244
+
245
+
246
+ conv_qwen_v1 = Conversation(
247
+ system="You are a helpful assistant.",
248
+ roles=("user", "assistant"),
249
+ version="v1",
250
+ messages=(),
251
+ offset=0,
252
+ sep_style=SeparatorStyle.QWEN2,
253
+ )
254
+
255
+ conv_plain = Conversation(
256
+ system="",
257
+ roles=("", ""),
258
+ messages=(),
259
+ offset=0,
260
+ sep_style=SeparatorStyle.PLAIN,
261
+ sep="</s>",
262
+ )
263
+
264
+ conv_qwen = Conversation(
265
+ system="""<|im_start|>system
266
+ You are a helpful assistant.""",
267
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
268
+ version="qwen",
269
+ messages=[],
270
+ offset=0,
271
+ sep_style=SeparatorStyle.CHATML,
272
+ sep="<|im_end|>",
273
+ )
274
+
275
+ default_conversation = conv_llama_3
276
+ conv_templates = {
277
+ "v1": conv_vicuna_v1,
278
+ "plain": conv_plain,
279
+ "llama_2": conv_llama_2,
280
+ "llama_3": conv_llama_3,
281
+ "v1_qwen2": conv_qwen_v1,
282
+ "qwen_1_5": conv_qwen,
283
+ }
284
+
285
+
286
+ if __name__ == "__main__":
287
+ print(default_conversation.get_prompt())
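The templates registered in conv_templates are used by copying one, appending alternating user/assistant turns, and rendering the prompt. A minimal usage sketch based on the code above (the question text is just an example):

from egogpt.conversation import conv_templates

conv = conv_templates["qwen_1_5"].copy()  # CHATML-style template
conv.append_message(conv.roles[0], "<image>\n<speech>\nWhat am I doing right now?")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open
prompt = conv.get_prompt()
# prompt now ends with "<|im_start|>assistant\n", ready for generation.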
egogpt/mm_utils.py ADDED
@@ -0,0 +1,450 @@
1
+ import ast
2
+ import base64
3
+ import math
4
+ import re
5
+ from io import BytesIO
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import StoppingCriteria
10
+
11
+
12
+ def resize_and_center_crop(image, shortest_edge_length):
13
+ # Calculate new dimensions and resize
14
+ aspect_ratio = float(image.width) / float(image.height)
15
+ if aspect_ratio > 1:
16
+ new_width = int(shortest_edge_length * aspect_ratio)
17
+ new_height = shortest_edge_length
18
+ else:
19
+ new_width = shortest_edge_length
20
+ new_height = int(shortest_edge_length / aspect_ratio)
21
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
22
+
23
+ # Calculate the position and perform the center crop
24
+ left = (new_width - shortest_edge_length) / 2
25
+ top = (new_height - shortest_edge_length) / 2
26
+ right = (new_width + shortest_edge_length) / 2
27
+ bottom = (new_height + shortest_edge_length) / 2
28
+ cropped_image = resized_image.crop((left, top, right, bottom))
29
+
30
+ return cropped_image
31
+
32
+
33
+ def auto_pad_images(image, grid_params):
34
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
35
+ assert len(grid_params) > 0, "Grid parameters should not be empty"
36
+
37
+ # Step 1: Calculate and find the closest aspect ratio
38
+ input_width, input_height = image.size
39
+ input_aspect_ratio = input_width / input_height
40
+ candidate_resolutions = [(w / h, w, h) for w in grid_params for h in grid_params]
41
+ closest_aspect_ratio = min(
42
+ candidate_resolutions, key=lambda x: abs(input_aspect_ratio - x[0])
43
+ )
44
+
45
+ candidate_resolutions = [
46
+ (x[1], x[2])
47
+ for x in candidate_resolutions
48
+ if abs(x[0] - closest_aspect_ratio[0]) < 1e-3
49
+ ]
50
+
51
+ target_resolution = min(
52
+ candidate_resolutions,
53
+ key=lambda res: abs(max(input_width, input_height) / max(res) - 1),
54
+ )
55
+
56
+ resize_width, resize_height = target_resolution
57
+ if input_width > input_height:
58
+ resize_height = int(resize_width / input_aspect_ratio)
59
+ else:
60
+ resize_width = int(resize_height * input_aspect_ratio)
61
+ resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
62
+
63
+ # Step 5: Pad the resized image if necessary to match the target resolution
64
+ pad_width = target_resolution[0] - resize_width
65
+ pad_height = target_resolution[1] - resize_height
66
+ padded_image = Image.new("RGB", target_resolution, color=(0, 0, 0))
67
+ padded_image.paste(resized_image, (pad_width // 2, pad_height // 2))
68
+
69
+ return padded_image
70
+
71
+
72
+ def extract_patches(image, patch_size, overlap_ratio):
73
+ assert isinstance(image, Image.Image), "Input should be a Pillow Image"
74
+ assert patch_size > 0, "Patch size should be greater than 0"
75
+ assert 0 <= overlap_ratio < 1, "Overlap ratio should be between 0 and 1"
76
+
77
+ W, H = image.size
78
+ patches = []
79
+
80
+ stride = int(patch_size * (1 - overlap_ratio))
81
+
82
+ num_patches_y = (H - patch_size) // stride + 1
83
+ num_patches_x = (W - patch_size) // stride + 1
84
+
85
+ y_start = (H - (num_patches_y - 1) * stride - patch_size) // 2
86
+ x_start = (W - (num_patches_x - 1) * stride - patch_size) // 2
87
+
88
+ for y in range(y_start, y_start + num_patches_y * stride, stride):
89
+ for x in range(x_start, x_start + num_patches_x * stride, stride):
90
+ patch = image.crop((x, y, x + patch_size, y + patch_size))
91
+ patches.append(patch)
92
+
93
+ return patches
94
+
95
+
96
+ def process_highres_image_crop_split(image, data_args, processor=None):
97
+ crop_resolution = data_args.image_crop_resolution
98
+ split_resolution = data_args.image_split_resolution
99
+ if processor is None:
100
+ processor = data_args.image_processor
101
+ image_crop = resize_and_center_crop(image, crop_resolution)
102
+ image_patches = extract_patches(
103
+ image_crop, patch_size=split_resolution, overlap_ratio=0
104
+ )
105
+ image_patches = [
106
+ processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
107
+ for image_patch in image_patches
108
+ ]
109
+ return torch.stack(image_patches, dim=0)
110
+
111
+
112
+ def process_highres_image(image, processor, grid_pinpoints):
113
+ grid_params = [int(x) for x in grid_pinpoints.split(",")]
114
+ width_height = max(image.size)
115
+ fit_grid_params = [x for x in grid_params if x >= width_height]
116
+ if len(fit_grid_params) == 0:
117
+ select_size = max(grid_params)
118
+ else:
119
+ select_size = min(fit_grid_params)
120
+ # FIXME: always select the 448
121
+ select_size = max(grid_params)
122
+ image_padded = expand2square(
123
+ image, tuple(int(x * 255) for x in processor.image_mean)
124
+ )
125
+
126
+ # FIXME: this seems to be a bug that it always resizes instead of padding
127
+ image_original_resize = image.resize(
128
+ (processor.size["shortest_edge"], processor.size["shortest_edge"])
129
+ )
130
+ image_padded = image_padded.resize((select_size, select_size))
131
+ image_patches = extract_patches(
132
+ image_padded, patch_size=processor.size["shortest_edge"], overlap_ratio=0
133
+ )
134
+ image_patches = [image_original_resize] + image_patches
135
+ image_patches = [
136
+ processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
137
+ for image_patch in image_patches
138
+ ]
139
+ return torch.stack(image_patches, dim=0)
140
+
141
+
142
+ def select_best_resolution(original_size, possible_resolutions):
143
+ """
144
+ Selects the best resolution from a list of possible resolutions based on the original size.
145
+
146
+ Args:
147
+ original_size (tuple): The original size of the image in the format (width, height).
148
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
149
+
150
+ Returns:
151
+ tuple: The best fit resolution in the format (width, height).
152
+ """
153
+ original_width, original_height = original_size
154
+ best_fit = None
155
+ max_effective_resolution = 0
156
+ min_wasted_resolution = float("inf")
157
+
158
+ for width, height in possible_resolutions:
159
+ # Calculate the downscaled size to keep the aspect ratio
160
+ scale = min(width / original_width, height / original_height)
161
+ downscaled_width, downscaled_height = int(original_width * scale), int(
162
+ original_height * scale
163
+ )
164
+
165
+ # Calculate effective and wasted resolutions
166
+ effective_resolution = min(
167
+ downscaled_width * downscaled_height, original_width * original_height
168
+ )
169
+ wasted_resolution = (width * height) - effective_resolution
170
+
171
+ if effective_resolution > max_effective_resolution or (
172
+ effective_resolution == max_effective_resolution
173
+ and wasted_resolution < min_wasted_resolution
174
+ ):
175
+ max_effective_resolution = effective_resolution
176
+ min_wasted_resolution = wasted_resolution
177
+ best_fit = (width, height)
178
+
179
+ return best_fit
180
+
181
+
182
+ def resize_and_pad_image(image, target_resolution):
183
+ """
184
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
185
+
186
+ Args:
187
+ image (PIL.Image.Image): The input image.
188
+ target_resolution (tuple): The target resolution (width, height) of the image.
189
+
190
+ Returns:
191
+ PIL.Image.Image: The resized and padded image.
192
+ """
193
+ original_width, original_height = image.size
194
+ target_width, target_height = target_resolution
195
+
196
+ # Determine which dimension (width or height) to fill
197
+ scale_w = target_width / original_width
198
+ scale_h = target_height / original_height
199
+
200
+ if scale_w < scale_h:
201
+ # Width will be filled completely
202
+ new_width = target_width
203
+ new_height = min(math.ceil(original_height * scale_w), target_height)
204
+ else:
205
+ # Height will be filled completely
206
+ new_height = target_height
207
+ new_width = min(math.ceil(original_width * scale_h), target_width)
208
+
209
+ # Resize the image
210
+ resized_image = image.resize((new_width, new_height))
211
+
212
+ # Create a new image with the target size and paste the resized image onto it
213
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
214
+ paste_x = (target_width - new_width) // 2
215
+ paste_y = (target_height - new_height) // 2
216
+ new_image.paste(resized_image, (paste_x, paste_y))
217
+
218
+ return new_image
219
+
220
+
221
+ def divide_to_patches(image, patch_size):
222
+ """
223
+ Divides an image into patches of a specified size.
224
+
225
+ Args:
226
+ image (PIL.Image.Image): The input image.
227
+ patch_size (int): The size of each patch.
228
+
229
+ Returns:
230
+ list: A list of PIL.Image.Image objects representing the patches.
231
+ """
232
+ patches = []
233
+ width, height = image.size
234
+ for i in range(0, height, patch_size):
235
+ for j in range(0, width, patch_size):
236
+ box = (j, i, j + patch_size, i + patch_size)
237
+ patch = image.crop(box)
238
+ patches.append(patch)
239
+
240
+ return patches
241
+
242
+
243
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
244
+ """
245
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
246
+
247
+ Args:
248
+ image_size (tuple): The size of the input image in the format (width, height).
249
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
250
+ patch_size (int): The size of each image patch.
251
+
252
+ Returns:
253
+ tuple: The shape of the image patch grid in the format (width, height).
254
+ """
255
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
256
+ assert patch_size in [
257
+ 224,
258
+ 336,
259
+ 384,
260
+ 448,
261
+ 512,
262
+ ], "patch_size should be in [224, 336, 384, 448, 512]"
263
+ # Use regex to extract the range from the input string
264
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
265
+ range_start = tuple(map(int, matches[0]))
266
+ range_end = tuple(map(int, matches[-1]))
267
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
268
+ grid_pinpoints = [
269
+ (i, j)
270
+ for i in range(range_start[0], range_end[0] + 1)
271
+ for j in range(range_start[1], range_end[1] + 1)
272
+ ]
273
+ # Multiply all elements by patch_size
274
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
275
+ if type(grid_pinpoints) is list:
276
+ possible_resolutions = grid_pinpoints
277
+ else:
278
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
279
+ width, height = select_best_resolution(image_size, possible_resolutions)
280
+ return width // patch_size, height // patch_size
281
+
282
+
283
+ def process_anyres_image(image, processor, grid_pinpoints):
284
+ """
285
+ Process an image with variable resolutions.
286
+
287
+ Args:
288
+ image (PIL.Image.Image): The input image to be processed.
289
+ processor: The image processor object.
290
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
291
+
292
+ Returns:
293
+ torch.Tensor: A tensor containing the processed image patches.
294
+ """
295
+ # Convert grid_pinpoints from string to list
296
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
297
+ try:
298
+ patch_size = processor.size[0]
299
+ except Exception as e:
300
+ patch_size = processor.size["shortest_edge"]
301
+ assert patch_size in [
302
+ 224,
303
+ 336,
304
+ 384,
305
+ 448,
306
+ 512,
307
+ ], "patch_size should be in [224, 336, 384, 448, 512]"
308
+ # Use regex to extract the range from the input string
309
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
310
+ range_start = tuple(map(int, matches[0]))
311
+ range_end = tuple(map(int, matches[-1]))
312
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
313
+ grid_pinpoints = [
314
+ (i, j)
315
+ for i in range(range_start[0], range_end[0] + 1)
316
+ for j in range(range_start[1], range_end[1] + 1)
317
+ ]
318
+ # Multiply all elements by patch_size
319
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
320
+
321
+ if type(grid_pinpoints) is list:
322
+ possible_resolutions = grid_pinpoints
323
+ else:
324
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
325
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
326
+ image_padded = resize_and_pad_image(image, best_resolution)
327
+
328
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
329
+
330
+ # FIXME: this seems to be a bug that it resizes instead of pad.
331
+ # but to keep it consistent with previous, i will keep it as it is
332
+ # TODO: uncomment below to ablate with the padding
333
+ if isinstance(processor.size, dict):
334
+ shortest_edge = processor.size["shortest_edge"]
335
+ else:
336
+ shortest_edge = min(processor.size)
337
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
338
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
339
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
340
+
341
+ image_patches = [image_original_resize] + patches
342
+ image_patches = [
343
+ processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
344
+ for image_patch in image_patches
345
+ ]
346
+ return torch.stack(image_patches, dim=0)
347
+
348
+
349
+ def load_image_from_base64(image):
350
+ return Image.open(BytesIO(base64.b64decode(image)))
351
+
352
+
353
+ def expand2square(pil_img, background_color):
354
+ width, height = pil_img.size
355
+ if width == height:
356
+ return pil_img
357
+ elif width > height:
358
+ result = Image.new(pil_img.mode, (width, width), background_color)
359
+ result.paste(pil_img, (0, (width - height) // 2))
360
+ return result
361
+ else:
362
+ result = Image.new(pil_img.mode, (height, height), background_color)
363
+ result.paste(pil_img, ((height - width) // 2, 0))
364
+ return result
365
+
366
+
367
+ def process_images(images, image_processor, model_cfg):
368
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
369
+ new_images = []
370
+ try:
371
+ image = images[0].convert("RGB")
372
+ except Exception as e:
373
+ print(f"Failed to open image {images[0]}. Exception:", e)
374
+ raise e
375
+
376
+ image_sizes = image.size
377
+ if image_aspect_ratio == "highres":
378
+ for image in images:
379
+ image = process_highres_image(
380
+ image, image_processor, model_cfg.image_grid_pinpoints
381
+ )
382
+ new_images.append(image)
383
+ elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
384
+ for image in images:
385
+ image = process_anyres_image(
386
+ image, image_processor, model_cfg.image_grid_pinpoints
387
+ )
388
+ new_images.append(image)
389
+ elif image_aspect_ratio == "crop_split":
390
+ for image in images:
391
+ image = process_highres_image_crop_split(image, model_cfg, image_processor)
392
+ new_images.append(image)
393
+ elif image_aspect_ratio == "pad":
394
+ for image in images:
395
+ image = expand2square(
396
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
397
+ )
398
+ image = image_processor.preprocess(image, return_tensors="pt")[
399
+ "pixel_values"
400
+ ][0]
401
+ new_images.append(image)
402
+ else:
403
+ return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
404
+ if all(x.shape == new_images[0].shape for x in new_images):
405
+ new_images = torch.stack(new_images, dim=0)
406
+ return new_images
407
+
408
+
409
+ def get_model_name_from_path(model_path):
410
+ model_path = model_path.strip("/")
411
+ model_paths = model_path.split("/")
412
+ if model_paths[-1].startswith("checkpoint-"):
413
+ return model_paths[-2] + "_" + model_paths[-1]
414
+ else:
415
+ return model_paths[-1]
416
+
417
+
418
+ class KeywordsStoppingCriteria(StoppingCriteria):
419
+ def __init__(self, keywords, tokenizer, input_ids):
420
+ self.keywords = keywords
421
+ self.keyword_ids = []
422
+ for keyword in keywords:
423
+ cur_keyword_ids = tokenizer(keyword).input_ids
424
+ if (
425
+ len(cur_keyword_ids) > 1
426
+ and cur_keyword_ids[0] == tokenizer.bos_token_id
427
+ ):
428
+ cur_keyword_ids = cur_keyword_ids[1:]
429
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
430
+ self.tokenizer = tokenizer
431
+ self.start_len = input_ids.shape[1]
432
+
433
+ def __call__(
434
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
435
+ ) -> bool:
436
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
437
+ offset = min(output_ids.shape[1] - self.start_len, 3)
438
+ self.keyword_ids = [
439
+ keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
440
+ ]
441
+ for keyword_id in self.keyword_ids:
442
+ if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
443
+ return True
444
+ outputs = self.tokenizer.batch_decode(
445
+ output_ids[:, -offset:], skip_special_tokens=True
446
+ )[0]
447
+ for keyword in self.keywords:
448
+ if keyword in outputs:
449
+ return True
450
+ return False
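A quick usage sketch for the preprocessing helpers above. The processor checkpoint and image path are placeholders chosen for illustration; the real values come from the model's config and your own data:

from types import SimpleNamespace

from PIL import Image
from transformers import AutoImageProcessor

from egogpt.mm_utils import get_model_name_from_path, process_images

processor = AutoImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
model_cfg = SimpleNamespace(image_aspect_ratio="pad")  # pad to square, then preprocess

images = [Image.open("frame_000.jpg").convert("RGB")]  # placeholder frame
pixel_values = process_images(images, processor, model_cfg)  # -> (1, 3, H, W) tensor

print(get_model_name_from_path("checkpoints/EgoGPT-7b/checkpoint-1000"))
# -> EgoGPT-7b_checkpoint-1000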
egogpt/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .language_model.egogpt_llama import EgoGPTConfig, EgoGPTLlamaForCausalLM
+ from .language_model.egogpt_qwen import EgoGPTConfigQwen, EgoGPTQwenForCausalLM
egogpt/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (343 Bytes).
 
egogpt/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.9 kB).
 
egogpt/model/__pycache__/egogpt_arch.cpython-310.pyc ADDED
Binary file (23.1 kB).
 
egogpt/model/builder.py ADDED
@@ -0,0 +1,127 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import shutil
18
+ import warnings
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ from transformers import (
23
+ AutoConfig,
24
+ AutoModelForCausalLM,
25
+ AutoTokenizer,
26
+ BitsAndBytesConfig,
27
+ )
28
+
29
+ from egogpt.model import *
30
+ from egogpt.model.speech_encoder.builder import build_speech_encoder
31
+
32
+
33
+ def load_pretrained_model(
34
+ model_path,
35
+ model_base=None,
36
+ is_lora=False,
37
+ load_8bit=False,
38
+ load_4bit=False,
39
+ device="cuda",
40
+ use_flash_attn=False,
41
+ **kwargs,
42
+ ):
43
+ # if dist.is_available() and not dist.is_initialized():
44
+ # dist.init_process_group(backend='nccl',init_method='env://')
45
+ if load_8bit:
46
+ kwargs["load_in_8bit"] = True
47
+ elif load_4bit:
48
+ kwargs["load_in_4bit"] = True
49
+ kwargs["quantization_config"] = BitsAndBytesConfig(
50
+ load_in_4bit=True,
51
+ bnb_4bit_compute_dtype=torch.float16,
52
+ bnb_4bit_use_double_quant=True,
53
+ bnb_4bit_quant_type="nf4",
54
+ )
55
+ else:
56
+ kwargs["torch_dtype"] = torch.float16
57
+
58
+ if use_flash_attn:
59
+ kwargs["attn_implementation"] = "flash_attention_2"
60
+
61
+ model_cls = EgoGPTQwenForCausalLM
62
+
63
+ # Load EgoGPT model
64
+ if is_lora:
65
+ assert model_base is not None, "model_base is required for LoRA models."
66
+ from egogpt.model.language_model.egogpt_llama import EgoGPTConfig
67
+
68
+ lora_cfg_pretrained = EgoGPTConfig.from_pretrained(model_path)
69
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
70
+ print("Loading EgoGPT from base model...")
71
+ model = model_cls.from_pretrained(
72
+ model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs
73
+ )
74
+ print("Loading additional EgoGPT weights...")
75
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
76
+ non_lora_trainables = torch.load(
77
+ os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu"
78
+ )
79
+ non_lora_trainables = {
80
+ (k[11:] if k.startswith("base_model.") else k): v
81
+ for k, v in non_lora_trainables.items()
82
+ }
83
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
84
+ non_lora_trainables = {
85
+ (k[6:] if k.startswith("model.") else k): v
86
+ for k, v in non_lora_trainables.items()
87
+ }
88
+ model.load_state_dict(non_lora_trainables, strict=False)
89
+
90
+ from peft import PeftModel
91
+
92
+ print("Loading LoRA weights...")
93
+ model = PeftModel.from_pretrained(model, model_path)
94
+ print("Merging LoRA weights...")
95
+ model = model.merge_and_unload()
96
+ print("Model is loaded...")
97
+ elif model_base is not None:
98
+ print("Loading EgoGPT from base model...")
99
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
100
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
101
+ model = model_cls.from_pretrained(
102
+ model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs
103
+ )
104
+
105
+ speech_projector_weights = torch.load(
106
+ os.path.join(model_path, "speech_projector.bin"), map_location="cpu"
107
+ )
108
+ speech_projector_weights = {
109
+ k: v.to(torch.float16) for k, v in speech_projector_weights.items()
110
+ }
111
+ model.load_state_dict(speech_projector_weights, strict=False)
112
+ model = model.to(device=device)
113
+ else:
114
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
115
+ model = model_cls.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
116
+ model = model.to(device=device)
117
+
118
+ context_len = 4096
119
+ # model.get_model().speech_encoder = build_speech_encoder(model.config)
120
+ # model.get_model().speech_encoder.to(device=device, dtype=torch.float16)
121
+
122
+ # if hasattr(model.config, "max_sequence_length"):
123
+ # context_len = model.config.max_sequence_length
124
+ # else:
125
+ # context_len = 2048
126
+
127
+ return tokenizer, model, context_len
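load_pretrained_model above covers three cases: a LoRA checkpoint merged onto a base model, a projector-only checkpoint loaded on top of a base model, and a fully merged checkpoint. A minimal sketch of the last case (the path is a placeholder):

from egogpt.model.builder import load_pretrained_model

tokenizer, model, context_len = load_pretrained_model(
    model_path="checkpoints/EgoGPT-7b",  # placeholder local or Hub path
    model_base=None,                     # fully merged checkpoint, no base needed
    device="cuda",
    use_flash_attn=False,
)
model.eval()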
egogpt/model/egogpt_arch.py ADDED
@@ -0,0 +1,1357 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ import re
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from egogpt.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX
24
+ from egogpt.mm_utils import get_anyres_image_grid_shape
25
+ from egogpt.utils import lengths_to_padding_mask, rank0_print, rank_print
26
+
27
+ from .multimodal_encoder.builder import build_vision_tower
28
+ from .multimodal_projector.builder import build_vision_projector
29
+ from .multimodal_resampler.builder import build_vision_resampler
30
+ from .speech_encoder.builder import build_speech_encoder
31
+ from .speech_projector.builder import build_speech_projector
32
+
33
+
34
+ class EgoGPTMetaModel:
35
+ def __init__(self, config):
36
+ super(EgoGPTMetaModel, self).__init__(config)
37
+
38
+ if hasattr(config, "mm_vision_tower"):
39
+ delay_load = getattr(config, "delay_load", False)
40
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
41
+ self.vision_resampler = build_vision_resampler(
42
+ config, vision_tower=self.vision_tower
43
+ )
44
+ self.mm_projector = build_vision_projector(
45
+ config, vision_cfg=self.vision_tower.config
46
+ )
47
+
48
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
49
+ self.image_newline = nn.Parameter(
50
+ torch.empty(config.hidden_size, dtype=self.dtype)
51
+ )
52
+
53
+ if hasattr(config, "speech_encoder"):
54
+ self.speech_encoder = build_speech_encoder(config)
55
+ self.speech_projector = build_speech_projector(config)
56
+
57
+ def get_vision_tower(self):
58
+ vision_tower = getattr(self, "vision_tower", None)
59
+ if type(vision_tower) is list:
60
+ vision_tower = vision_tower[0]
61
+ return vision_tower
62
+
63
+ def get_speech_encoder(self):
64
+ speech_encoder = getattr(self, "speech_encoder", None)
65
+ if type(speech_encoder) is list:
66
+ speech_encoder = speech_encoder[0]
67
+ return speech_encoder
68
+
69
+ def initialize_vision_modules(self, model_args, fsdp=None):
70
+ vision_tower = model_args.vision_tower
71
+ mm_vision_select_layer = model_args.mm_vision_select_layer
72
+ mm_vision_select_feature = model_args.mm_vision_select_feature
73
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
74
+ mm_patch_merge_type = model_args.mm_patch_merge_type
75
+
76
+ self.config.mm_vision_tower = vision_tower
77
+ self.config.vision_tower_pretrained = getattr(
78
+ model_args, "vision_tower_pretrained", ""
79
+ )
80
+
81
+ if self.get_vision_tower() is None:
82
+ vision_tower = build_vision_tower(model_args)
83
+ vision_resampler = build_vision_resampler(
84
+ model_args, vision_tower=vision_tower
85
+ )
86
+ for k, v in vision_resampler.config.items():
87
+ setattr(self.config, k, v)
88
+
89
+ if fsdp is not None and len(fsdp) > 0:
90
+ self.vision_tower = [vision_tower]
91
+ self.vision_resampler = [vision_resampler]
92
+ else:
93
+ self.vision_tower = vision_tower
94
+ self.vision_resampler = vision_resampler
95
+ else:
96
+ if fsdp is not None and len(fsdp) > 0:
97
+ vision_resampler = self.vision_resampler[0]
98
+ vision_tower = self.vision_tower[0]
99
+ else:
100
+ vision_resampler = self.vision_resampler
101
+ vision_tower = self.vision_tower
102
+ vision_tower.load_model()
103
+
104
+ # In case it is frozen by LoRA
105
+ for p in self.vision_resampler.parameters():
106
+ p.requires_grad = True
107
+
108
+ self.config.use_mm_proj = True
109
+ self.config.mm_projector_type = getattr(
110
+ model_args, "mm_projector_type", "linear"
111
+ )
112
+ self.config.mm_hidden_size = getattr(
113
+ vision_resampler, "hidden_size", vision_tower.hidden_size
114
+ )
115
+ self.config.mm_vision_select_layer = mm_vision_select_layer
116
+ self.config.mm_vision_select_feature = mm_vision_select_feature
117
+ self.config.mm_patch_merge_type = mm_patch_merge_type
118
+
119
+ if not hasattr(self.config, "add_faster_video"):
120
+ if model_args.add_faster_video:
121
+ embed_std = 1 / torch.sqrt(
122
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
123
+ )
124
+ self.faster_token = nn.Parameter(
125
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
126
+ )
127
+
128
+ if getattr(self, "mm_projector", None) is None:
129
+ self.mm_projector = build_vision_projector(
130
+ self.config, vision_cfg=vision_tower.config
131
+ )
132
+
133
+ if "unpad" in mm_patch_merge_type:
134
+ embed_std = 1 / torch.sqrt(
135
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
136
+ )
137
+ self.image_newline = nn.Parameter(
138
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
139
+ )
140
+ else:
141
+ # In case it is frozen by LoRA
142
+ for p in self.mm_projector.parameters():
143
+ p.requires_grad = True
144
+
145
+ if pretrain_mm_mlp_adapter is not None:
146
+ mm_projector_weights = torch.load(
147
+ pretrain_mm_mlp_adapter, map_location="cpu"
148
+ )
149
+
150
+ def get_w(weights, keyword):
151
+ return {
152
+ k.split(keyword + ".")[1]: v
153
+ for k, v in weights.items()
154
+ if keyword in k
155
+ }
156
+
157
+ incompatible_keys = self.mm_projector.load_state_dict(
158
+ get_w(mm_projector_weights, "mm_projector")
159
+ )
160
+ rank0_print(
161
+ f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}"
162
+ )
163
+ incompatible_keys = self.vision_resampler.load_state_dict(
164
+ get_w(mm_projector_weights, "vision_resampler"), strict=False
165
+ )
166
+ rank0_print(
167
+ f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}"
168
+ )
169
+
170
+ def initialize_speech_modules(self, model_args, fsdp=None):
171
+ self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
172
+ self.config.speech_encoder_type = getattr(
173
+ model_args, "speech_encoder_type", None
174
+ )
175
+ self.config.speech_projector_type = getattr(
176
+ model_args, "speech_projector_type", "linear"
177
+ )
178
+ self.config.speech_encoder_ds_rate = getattr(
179
+ model_args, "speech_encoder_ds_rate", 5
180
+ )
181
+ self.config.speech_encoder_hidden_size = getattr(
182
+ model_args, "speech_encoder_hidden_size", 1280
183
+ )
184
+ self.config.delay_load_audio = getattr(model_args, "delay_load_audio", True)
185
+
186
+ if self.get_speech_encoder() is None:
187
+ speech_encoder = build_speech_encoder(self.config)
188
+ if fsdp is not None and len(fsdp) > 0:
189
+ self.speech_encoder = [speech_encoder]
190
+ else:
191
+ self.speech_encoder = speech_encoder
192
+ else:
193
+ if fsdp is not None and len(fsdp) > 0:
194
+ speech_encoder = self.speech_encoder[0]
195
+ else:
196
+ speech_encoder = self.speech_encoder
197
+ speech_encoder.load_model(self.config)
198
+
199
+ if getattr(self, "speech_projector", None) is None:
200
+ self.speech_projector = build_speech_projector(self.config)
201
+ else:
202
+ # In case it is frozen by LoRA
203
+ for p in self.speech_projector.parameters():
204
+ p.requires_grad = True
205
+
206
+ if model_args.pretrain_speech_projector is not None:
207
+ pretrain_speech_projector_weights = torch.load(
208
+ model_args.pretrain_speech_projector, map_location="cpu"
209
+ )
210
+
211
+ def get_w(weights, keyword):
212
+ return {
213
+ k.split(keyword + ".")[1]: v
214
+ for k, v in weights.items()
215
+ if keyword in k
216
+ }
217
+
218
+ self.speech_projector.load_state_dict(
219
+ get_w(pretrain_speech_projector_weights, "speech_projector"),
220
+ strict=False,
221
+ )
222
+
223
+
224
+ def unpad_image(tensor, original_size):
225
+ """
226
+ Unpads a PyTorch tensor of a padded and resized image.
227
+
228
+ Args:
229
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
230
+ original_size (tuple): The original size of the image (height, width).
231
+
232
+ Returns:
233
+ torch.Tensor: The unpadded image tensor.
234
+ """
235
+ original_width, original_height = original_size
236
+ current_height, current_width = tensor.shape[1:]
237
+
238
+ # Compute aspect ratios
239
+ original_aspect_ratio = original_width / original_height
240
+ current_aspect_ratio = current_width / current_height
241
+
242
+ # Determine padding size and direction
243
+ if original_aspect_ratio > current_aspect_ratio:
244
+ # Padding was added to the height
245
+ scale_factor = current_width / original_width
246
+ new_height = int(original_height * scale_factor)
247
+ padding = (current_height - new_height) // 2
248
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
249
+ else:
250
+ # Padding was added to the width
251
+ scale_factor = current_height / original_height
252
+ new_width = int(original_width * scale_factor)
253
+ padding = (current_width - new_width) // 2
254
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
255
+
256
+ return unpadded_tensor
257
+
258
+
259
+ class EgoGPTMetaForCausalLM(ABC):
260
+ @abstractmethod
261
+ def get_model(self):
262
+ pass
263
+
264
+ def get_speech_encoder(self):
265
+ return self.get_model().get_speech_encoder()
266
+
267
+ def get_speech_projector(self):
268
+ return self.get_model().speech_projector
269
+
270
+ def get_vision_tower(self):
271
+ return self.get_model().get_vision_tower()
272
+
273
+ def get_2dPool(self, image_feature, stride=2):
274
+ height = width = self.get_vision_tower().num_patches_per_side
275
+ num_frames, num_tokens, num_dim = image_feature.shape
276
+ image_feature = image_feature.view(num_frames, height, width, -1)
277
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
278
+ image_feature = nn.functional.avg_pool2d(image_feature, stride)
279
+ # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
280
+ # if self.config.mm_spatial_pool_mode == "average":
281
+ # image_feature = nn.functional.avg_pool2d(image_feature, stride)
282
+ # elif self.config.mm_spatial_pool_mode == "max":
283
+ # image_feature = nn.functional.max_pool2d(image_feature, stride)
284
+ # elif self.config.mm_spatial_pool_mode == "bilinear":
285
+ # height, width = image_feature.shape[2:]
286
+ # scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
287
+ # image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
288
+ # else:
289
+ # raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
290
+ image_feature = image_feature.permute(0, 2, 3, 1)
291
+ image_feature = image_feature.view(num_frames, -1, num_dim)
292
+ return image_feature
293
+
294
+ def encode_images(self, images):
295
+ image_features = self.get_model().get_vision_tower()(images)
296
+ # image_features = self.get_model().vision_resampler(image_features, images=images)
297
+ image_features = self.get_model().mm_projector(image_features)
298
+ return image_features
299
+
300
+ def encode_speech(self, speech, speech_lengths):
301
+ # audio cutting
302
+ speech_encoder_type = self.config.speech_encoder_type
303
+ speech_encoder = self.get_speech_encoder()
304
+ if "whisper" in speech_encoder_type.lower():
305
+ encoder_outs = speech_encoder(speech.permute(0, 2, 1))
306
+ speech_lengths = (speech_lengths + 1) // 2
307
+ else:
308
+ raise ValueError(f"Unknown speech encoder: {speech_encoder}")
309
+ speech_projector_type = self.config.speech_projector_type
310
+ speech_projector = self.get_speech_projector()
311
+ if speech_projector_type == "linear":
312
+ encoder_outs = speech_projector(encoder_outs)
313
+ speech_lengths = speech_lengths // speech_projector.k
314
+ else:
315
+ raise ValueError(f"Unknown speech projector: {speech_projector_type}")
316
+ speech_features = [
317
+ encoder_outs[i, : speech_lengths[i]] for i in range(len(encoder_outs))
318
+ ]
319
+ return speech_features
320
+
321
+ def add_token_per_grid(self, image_feature):
322
+ resize_h = int(math.sqrt(image_feature.shape[1]))
323
+ num_frames = image_feature.shape[0]
324
+ feature_dim = image_feature.shape[-1]
325
+
326
+ image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
327
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
328
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
329
+ image_feature = torch.cat(
330
+ (
331
+ image_feature,
332
+ self.model.image_newline[:, None, None]
333
+ .expand(*image_feature.shape[:-1], 1)
334
+ .to(image_feature.device),
335
+ ),
336
+ dim=-1,
337
+ )
338
+ if getattr(self.config, "add_faster_video", False):
339
+ # import pdb; pdb.set_trace()
340
+ # (3584, 832, 14) -> (3584, 64, 13, 14)
341
+ image_feature = image_feature.view(feature_dim, num_frames, resize_h, -1)
342
+ # (3584, 64, 13, 14) -> (64, 13, 14, 3584)
343
+ image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
344
+ # (64, 13, 14, 3584) -> (64, 13*14, 3584)
345
+ image_feature = image_feature.flatten(1, 2)
346
+ # import pdb; pdb.set_trace()
347
+ return image_feature
348
+ # import pdb; pdb.set_trace()
349
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
350
+ return image_feature
351
+
352
+ def prepare_inputs_labels_for_speech_and_text(
353
+ self,
354
+ input_ids,
355
+ position_ids,
356
+ attention_mask,
357
+ past_key_values,
358
+ labels,
359
+ speech,
360
+ speech_lengths,
361
+ images,
362
+ image_sizes=None,
363
+ modalities=["image"],
364
+ ):
365
+ vision_tower = self.get_vision_tower()
366
+ # rank_print(modalities)
367
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
368
+ return (
369
+ input_ids,
370
+ position_ids,
371
+ attention_mask,
372
+ past_key_values,
373
+ None,
374
+ labels,
375
+ )
376
+ speech_encoder = self.get_speech_encoder()
377
+ if speech_encoder is None or speech is None or input_ids.shape[1] == 1:
378
+ return (
379
+ input_ids,
380
+ position_ids,
381
+ attention_mask,
382
+ past_key_values,
383
+ None,
384
+ labels,
385
+ )
386
+
387
+ speech_features = self.encode_speech(speech, speech_lengths)
388
+
389
+ if isinstance(modalities, str):
390
+ modalities = [modalities]
391
+
392
+ # import pdb; pdb.set_trace()
393
+ if type(images) is list or images.ndim == 5:
394
+ if type(images) is list:
395
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
396
+
397
+ video_idx_in_batch = []
398
+ for _ in range(len(modalities)):
399
+ if modalities[_] == "video":
400
+ video_idx_in_batch.append(_)
401
+
402
+ # print(f"Images: {images}, {type(images)}, {len(images)}")
403
+ # print(f"Video idx in batch: {modalities}")
404
+ images_list = []
405
+ for image in images:
406
+ if image.ndim == 4:
407
+ images_list.append(image)
408
+ else:
409
+ images_list.append(image.unsqueeze(0))
410
+
411
+ # concat_images = torch.cat([torch.tensor(image) for image in images_list], dim=0)
412
+ concat_images = torch.cat([image for image in images_list], dim=0)
413
+ split_sizes = [image.shape[0] for image in images_list]
414
+ concat_images.requires_grad_(True)
415
+ encoded_image_features = self.encode_images(concat_images)
416
+ # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
417
+
418
+ # This is a list, each element is [num_images, patch * patch, dim]
419
+ # rank_print(f"Concat images : {concat_images.shape}")
420
+ encoded_image_features = torch.split(encoded_image_features, split_sizes)
421
+ image_features = []
422
+ for idx, image_feat in enumerate(encoded_image_features):
423
+ if idx in video_idx_in_batch:
424
+ image_features.append(self.get_2dPool(image_feat))
425
+ else:
426
+ image_features.append(image_feat)
427
+ # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
428
+ # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
429
+ # image_features = torch.split(image_features, split_sizes, dim=0)
430
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
431
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
432
+ mm_newline_position = getattr(
433
+ self.config, "mm_newline_position", "one_token"
434
+ )
435
+
436
+ if mm_patch_merge_type == "flat":
437
+ image_features = [x.flatten(0, 1) for x in image_features]
438
+
439
+ elif mm_patch_merge_type.startswith("spatial"):
440
+ new_image_features = []
441
+ for image_idx, image_feature in enumerate(image_features):
442
+ # FIXME: now assume the image is square, and split to 2x2 patches
443
+ # num_patches = h * w, where h = w = sqrt(num_patches)
444
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
445
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
446
+ # rank0_print("At least we are reaching here")
447
+ # import pdb; pdb.set_trace()
448
+ if image_idx in video_idx_in_batch: # video operations
449
+ # rank0_print("Video")
450
+ if mm_newline_position == "grid":
451
+ # Grid-wise
452
+ image_feature = self.add_token_per_grid(image_feature)
453
+ if getattr(self.config, "add_faster_video", False):
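+ # NOTE: `all_faster_video_features` is only produced by the `encode_multimodals`
+ # call commented out above; with the current code this add_faster_video branch
+ # would raise a NameError unless that call is restored.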
454
+ faster_video_feature = self.add_token_per_grid(
455
+ all_faster_video_features[image_idx]
456
+ )
457
+ # Add a token for each frame
458
+ concat_slow_faster_token = []
459
+ # import pdb; pdb.set_trace()
460
+ for _ in range(image_feature.shape[0]):
461
+ if _ % self.config.faster_token_stride == 0:
462
+ concat_slow_faster_token.append(
463
+ torch.cat(
464
+ (
465
+ image_feature[_],
466
+ self.model.faster_token[None].to(
467
+ image_feature.device
468
+ ),
469
+ ),
470
+ dim=0,
471
+ )
472
+ )
473
+ else:
474
+ concat_slow_faster_token.append(
475
+ torch.cat(
476
+ (
477
+ faster_video_feature[_],
478
+ self.model.faster_token[None].to(
479
+ image_feature.device
480
+ ),
481
+ ),
482
+ dim=0,
483
+ )
484
+ )
485
+ # import pdb; pdb.set_trace()
486
+ image_feature = torch.cat(concat_slow_faster_token)
487
+
488
+ new_image_features.append(image_feature)
489
+ elif mm_newline_position == "frame":
490
+ # Frame-wise
491
+ image_feature = self.add_token_per_frame(image_feature)
492
+
493
+ new_image_features.append(image_feature.flatten(0, 1))
494
+
495
+ elif mm_newline_position == "one_token":
496
+ # one-token
497
+ image_feature = image_feature.flatten(0, 1)
498
+ if "unpad" in mm_patch_merge_type:
499
+ image_feature = torch.cat(
500
+ (
501
+ image_feature,
502
+ self.model.image_newline[None].to(
503
+ image_feature.device
504
+ ),
505
+ ),
506
+ dim=0,
507
+ )
508
+ new_image_features.append(image_feature)
509
+ elif mm_newline_position == "no_token":
510
+ new_image_features.append(image_feature.flatten(0, 1))
511
+ else:
512
+ raise ValueError(
513
+ f"Unexpected mm_newline_position: {mm_newline_position}"
514
+ )
515
+ elif (
516
+ image_feature.shape[0] > 1
517
+ ): # multi patches and multi images operations
518
+ # rank0_print("Single-images")
519
+ base_image_feature = image_feature[0]
520
+ image_feature = image_feature[1:]
521
+ height = width = self.get_vision_tower().num_patches_per_side
522
+ assert height * width == base_image_feature.shape[0]
523
+
524
+ if "anyres_max" in image_aspect_ratio:
525
+ matched_anyres_max_num_patches = re.match(
526
+ r"anyres_max_(\d+)", image_aspect_ratio
527
+ )
528
+ if matched_anyres_max_num_patches:
529
+ max_num_patches = int(
530
+ matched_anyres_max_num_patches.group(1)
531
+ )
532
+
533
+ if (
534
+ image_aspect_ratio == "anyres"
535
+ or "anyres_max" in image_aspect_ratio
536
+ ):
537
+ if hasattr(self.get_vision_tower(), "image_size"):
538
+ vision_tower_image_size = (
539
+ self.get_vision_tower().image_size
540
+ )
541
+ else:
542
+ raise ValueError(
543
+ "vision_tower_image_size is not found in the vision tower."
544
+ )
545
+ try:
546
+ (
547
+ num_patch_width,
548
+ num_patch_height,
549
+ ) = get_anyres_image_grid_shape(
550
+ image_sizes[image_idx],
551
+ self.config.image_grid_pinpoints,
552
+ vision_tower_image_size,
553
+ )
554
+ except Exception as e:
555
+ rank0_print(f"Error: {e}")
556
+ num_patch_width, num_patch_height = 2, 2
557
+ image_feature = image_feature.view(
558
+ num_patch_height, num_patch_width, height, width, -1
559
+ )
560
+ else:
561
+ image_feature = image_feature.view(2, 2, height, width, -1)
562
+
563
+ if "maxpool2x2" in mm_patch_merge_type:
564
+ image_feature = image_feature.permute(
565
+ 4, 0, 2, 1, 3
566
+ ).contiguous()
567
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
568
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
569
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
570
+ elif (
571
+ "unpad" in mm_patch_merge_type
572
+ and "anyres_max" in image_aspect_ratio
573
+ and matched_anyres_max_num_patches
574
+ ):
575
+ unit = image_feature.shape[2]
576
+ image_feature = image_feature.permute(
577
+ 4, 0, 2, 1, 3
578
+ ).contiguous()
579
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
580
+ image_feature = unpad_image(
581
+ image_feature, image_sizes[image_idx]
582
+ )
583
+ c, h, w = image_feature.shape
584
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
585
+ if times > 1.1:
586
+ image_feature = image_feature[None]
587
+ image_feature = nn.functional.interpolate(
588
+ image_feature,
589
+ [int(h // times), int(w // times)],
590
+ mode="bilinear",
591
+ )[0]
592
+ image_feature = torch.cat(
593
+ (
594
+ image_feature,
595
+ self.model.image_newline[:, None, None]
596
+ .expand(*image_feature.shape[:-1], 1)
597
+ .to(image_feature.device),
598
+ ),
599
+ dim=-1,
600
+ )
601
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
602
+ elif "unpad" in mm_patch_merge_type:
603
+ image_feature = image_feature.permute(
604
+ 4, 0, 2, 1, 3
605
+ ).contiguous()
606
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
607
+ image_feature = unpad_image(
608
+ image_feature, image_sizes[image_idx]
609
+ )
610
+ image_feature = torch.cat(
611
+ (
612
+ image_feature,
613
+ self.model.image_newline[:, None, None]
614
+ .expand(*image_feature.shape[:-1], 1)
615
+ .to(image_feature.device),
616
+ ),
617
+ dim=-1,
618
+ )
619
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
620
+ else:
621
+ image_feature = image_feature.permute(
622
+ 0, 2, 1, 3, 4
623
+ ).contiguous()
624
+ image_feature = image_feature.flatten(0, 3)
625
+ if "nobase" in mm_patch_merge_type:
626
+ pass
627
+ else:
628
+ image_feature = torch.cat(
629
+ (base_image_feature, image_feature), dim=0
630
+ )
631
+ new_image_features.append(image_feature)
632
+ else: # single image operations
633
+ image_feature = image_feature[0]
634
+ if "unpad" in mm_patch_merge_type:
635
+ image_feature = torch.cat(
636
+ (image_feature, self.model.image_newline[None]), dim=0
637
+ )
638
+
639
+ new_image_features.append(image_feature)
640
+ image_features = new_image_features
641
+ else:
642
+ raise ValueError(
643
+ f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}"
644
+ )
645
+ else:
646
+ image_features = self.encode_images(images)
647
+
648
+ # TODO: image start / end is not implemented here to support pretraining.
649
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
650
+ self.config, "mm_use_im_start_end", False
651
+ ):
652
+ raise NotImplementedError
653
+ # Let's just add dummy tensors if they do not exist,
654
+ # it is a headache to deal with None all the time.
655
+ # But it is not ideal, and if you have a better idea,
656
+ # please open an issue / submit a PR, thanks.
657
+ _labels = labels
658
+ _position_ids = position_ids
659
+ _attention_mask = attention_mask
660
+ if attention_mask is None:
661
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
662
+ else:
663
+ attention_mask = attention_mask.bool()
664
+ if position_ids is None:
665
+ position_ids = torch.arange(
666
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
667
+ )
668
+ if labels is None:
669
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
670
+
671
+ # remove the padding using attention_mask -- FIXME
672
+ _input_ids = input_ids
673
+ input_ids = [
674
+ cur_input_ids[cur_attention_mask]
675
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
676
+ ]
677
+ labels = [
678
+ cur_labels[cur_attention_mask]
679
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
680
+ ]
681
+ new_input_embeds = []
682
+ new_labels = []
683
+ cur_speech_idx = 0
684
+ cur_image_idx = 0
685
+ for batch_idx, cur_input_ids in enumerate(input_ids):
686
+ num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
687
+ num_image = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
688
+ # if num_speech:
689
+ # print("has <speech>")
690
+ # if num_image:
691
+ # print("has <image>")
692
+ num_speech_images = num_speech + num_image
693
+
694
+ if num_speech_images == 0:
695
+ cur_speech_features = speech_features[cur_speech_idx]
696
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
697
+ cur_input_embeds = torch.cat(
698
+ [cur_input_embeds_1, cur_speech_features[0:0]], dim=0
699
+ )
700
+ new_input_embeds.append(cur_input_embeds)
701
+ new_labels.append(labels[batch_idx])
702
+ cur_speech_idx += 1
703
+ cur_image_idx += 1
704
+ continue
705
+
706
+ multimodal_token_indices = (
707
+ [-1]
708
+ + torch.where(
709
+ (cur_input_ids == SPEECH_TOKEN_INDEX)
710
+ | (cur_input_ids == IMAGE_TOKEN_INDEX)
711
+ )[0].tolist()
712
+ + [cur_input_ids.shape[0]]
713
+ )
714
+
715
+ cur_input_ids_nospeech_image = []
716
+ cur_labels = labels[batch_idx]
717
+ cur_labels_nospeech_image = []
718
+ for i in range(len(multimodal_token_indices) - 1):
719
+ cur_input_ids_nospeech_image.append(
720
+ cur_input_ids[
721
+ multimodal_token_indices[i]
722
+ + 1 : multimodal_token_indices[i + 1]
723
+ ]
724
+ )
725
+ cur_labels_nospeech_image.append(
726
+ cur_labels[
727
+ multimodal_token_indices[i]
728
+ + 1 : multimodal_token_indices[i + 1]
729
+ ]
730
+ )
731
+
732
+ split_sizes = [x.shape[0] for x in cur_labels_nospeech_image]
733
+ cur_input_embeds = self.get_model().embed_tokens(
734
+ torch.cat(cur_input_ids_nospeech_image)
735
+ )
736
+ cur_input_embeds_no_speech_image = torch.split(
737
+ cur_input_embeds, split_sizes, dim=0
738
+ )
739
+ cur_new_input_embeds = []
740
+ cur_new_labels = []
741
+
742
+ for i in range(num_speech_images + 1):
743
+ cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i])
744
+ cur_new_labels.append(cur_labels_nospeech_image[i])
745
+ if i < num_speech_images:
746
+ if i < num_image:
747
+ cur_images_features = image_features[cur_image_idx]
748
+ cur_image_idx += 1
749
+ cur_new_input_embeds.append(cur_images_features)
750
+ cur_new_labels.append(
751
+ torch.full(
752
+ (cur_images_features.shape[0],),
753
+ IGNORE_INDEX,
754
+ device=cur_labels.device,
755
+ dtype=cur_labels.dtype,
756
+ )
757
+ )
758
+ else:
759
+ cur_speech_features = speech_features[cur_speech_idx]
760
+ cur_speech_idx += 1
761
+ cur_new_input_embeds.append(cur_speech_features)
762
+ cur_new_labels.append(
763
+ torch.full(
764
+ (cur_speech_features.shape[0],),
765
+ IGNORE_INDEX,
766
+ device=cur_labels.device,
767
+ dtype=cur_labels.dtype,
768
+ )
769
+ )
770
+
771
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
772
+
773
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
774
+ cur_new_labels = torch.cat(cur_new_labels)
775
+
776
+ if num_image == 0:
777
+ cur_new_input_embeds = torch.cat(
778
+ [cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0
779
+ )
780
+ cur_image_idx += 1
781
+
782
+ if num_speech == 0:
783
+ cur_new_input_embeds = torch.cat(
784
+ [cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0
785
+ )
786
+ cur_speech_idx += 1
787
+
788
+ new_input_embeds.append(cur_new_input_embeds)
789
+ new_labels.append(cur_new_labels)
790
+
791
+ # Truncate sequences to max length as speech features can make the sequence longer
792
+ tokenizer_model_max_length = getattr(
793
+ self.config, "tokenizer_model_max_length", None
794
+ )
795
+ if tokenizer_model_max_length is not None:
796
+ new_input_embeds = [
797
+ x[:tokenizer_model_max_length] for x in new_input_embeds
798
+ ]
799
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
800
+
801
+ # Combine them
802
+ max_len = max(x.shape[0] for x in new_input_embeds)
803
+ batch_size = len(new_input_embeds)
804
+
805
+ new_input_embeds_padded = []
806
+ new_labels_padded = torch.full(
807
+ (batch_size, max_len),
808
+ IGNORE_INDEX,
809
+ dtype=new_labels[0].dtype,
810
+ device=new_labels[0].device,
811
+ )
812
+ attention_mask = torch.zeros(
813
+ (batch_size, max_len),
814
+ dtype=attention_mask.dtype,
815
+ device=attention_mask.device,
816
+ )
817
+ position_ids = torch.zeros(
818
+ (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
819
+ )
820
+
821
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
822
+ zip(new_input_embeds, new_labels)
823
+ ):
824
+ cur_len = cur_new_embed.shape[0]
825
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
826
+ new_input_embeds_padded.append(
827
+ torch.cat(
828
+ (
829
+ torch.zeros(
830
+ (max_len - cur_len, cur_new_embed.shape[1]),
831
+ dtype=cur_new_embed.dtype,
832
+ device=cur_new_embed.device,
833
+ ),
834
+ cur_new_embed,
835
+ ),
836
+ dim=0,
837
+ )
838
+ )
839
+ if cur_len > 0:
840
+ new_labels_padded[i, -cur_len:] = cur_new_labels
841
+ attention_mask[i, -cur_len:] = True
842
+ position_ids[i, -cur_len:] = torch.arange(
843
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
844
+ )
845
+ else:
846
+ new_input_embeds_padded.append(
847
+ torch.cat(
848
+ (
849
+ cur_new_embed,
850
+ torch.zeros(
851
+ (max_len - cur_len, cur_new_embed.shape[1]),
852
+ dtype=cur_new_embed.dtype,
853
+ device=cur_new_embed.device,
854
+ ),
855
+ ),
856
+ dim=0,
857
+ )
858
+ )
859
+ if cur_len > 0:
860
+ new_labels_padded[i, :cur_len] = cur_new_labels
861
+ attention_mask[i, :cur_len] = True
862
+ position_ids[i, :cur_len] = torch.arange(
863
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
864
+ )
865
+
866
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
867
+ if _labels is None:
868
+ new_labels = None
869
+ else:
870
+ new_labels = new_labels_padded
871
+
872
+ if _attention_mask is None:
873
+ attention_mask = None
874
+ else:
875
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
876
+
877
+ if _position_ids is None:
878
+ position_ids = None
879
+
880
+ return (
881
+ None,
882
+ position_ids,
883
+ attention_mask,
884
+ past_key_values,
885
+ new_input_embeds,
886
+ new_labels,
887
+ )
888
+
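The method above splits every sequence at the speech/image placeholder ids, embeds the plain-text chunks, interleaves the encoder features between them (with IGNORE_INDEX labels), and then re-pads the batch. A minimal standalone sketch of the chunking step; the sentinel values below are hypothetical, not the constants from egogpt/constants.py:

import torch

IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX = -200, -300  # hypothetical sentinels, not the repo's values
input_ids = torch.tensor([11, 12, IMAGE_TOKEN_INDEX, 13, SPEECH_TOKEN_INDEX, 14])

is_placeholder = (input_ids == IMAGE_TOKEN_INDEX) | (input_ids == SPEECH_TOKEN_INDEX)
bounds = [-1] + torch.where(is_placeholder)[0].tolist() + [input_ids.shape[0]]
text_chunks = [input_ids[bounds[i] + 1 : bounds[i + 1]] for i in range(len(bounds) - 1)]
print([c.tolist() for c in text_chunks])  # [[11, 12], [13], [14]]
# The image/speech feature tensors are spliced between these embedded chunks, and their
# label positions are filled with IGNORE_INDEX before the batch is padded back together.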
889
+ def prepare_inputs_labels_for_speech_and_text_debug(
890
+ self,
891
+ input_ids,
892
+ position_ids,
893
+ attention_mask,
894
+ past_key_values,
895
+ labels,
896
+ speech,
897
+ speech_lengths,
898
+ images,
899
+ image_sizes=None,
900
+ modalities=["image"],
901
+ ):
902
+ # vision_tower = self.get_vision_tower()
903
+ # # rank_print(modalities)
904
+ # if vision_tower is None or images is None or input_ids.shape[1] == 1:
905
+ # return input_ids, position_ids, attention_mask, past_key_values, None, labels
906
+ # speech_encoder = self.get_speech_encoder()
907
+ # if speech_encoder is None or speech is None or input_ids.shape[1] == 1:
908
+ # return input_ids, position_ids, attention_mask, past_key_values, None, labels
909
+
910
+ speech_features = self.encode_speech(speech, speech_lengths)
911
+
912
+ if isinstance(modalities, str):
913
+ modalities = [modalities]
914
+
915
+ # import pdb; pdb.set_trace()
916
+ if type(images) is list or images.ndim == 5:
917
+ if type(images) is list:
918
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
919
+
920
+ video_idx_in_batch = []
921
+ for _ in range(len(modalities)):
922
+ if modalities[_] == "video":
923
+ video_idx_in_batch.append(_)
924
+
925
+ # print(f"Images: {images}, {type(images)}, {len(images)}")
926
+ # print(f"Video idx in batch: {modalities}")
927
+ images_list = []
928
+ for image in images:
929
+ if image.ndim == 4:
930
+ images_list.append(image)
931
+ else:
932
+ images_list.append(image.unsqueeze(0))
933
+
934
+ # concat_images = torch.cat([torch.tensor(image) for image in images_list], dim=0)
935
+ concat_images = torch.cat([image for image in images_list], dim=0)
936
+ split_sizes = [image.shape[0] for image in images_list]
937
+ concat_images.requires_grad_(True)
938
+ encoded_image_features = self.encode_images(concat_images)
939
+ # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
940
+
941
+ # This is a list, each element is [num_images, patch * patch, dim]
942
+ # rank_print(f"Concat images : {concat_images.shape}")
943
+ encoded_image_features = torch.split(encoded_image_features, split_sizes)
944
+ image_features = []
945
+ for idx, image_feat in enumerate(encoded_image_features):
946
+ if idx in video_idx_in_batch:
947
+ image_features.append(self.get_2dPool(image_feat))
948
+ else:
949
+ image_features.append(image_feat)
950
+ # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
951
+ # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
952
+ # image_features = torch.split(image_features, split_sizes, dim=0)
953
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
954
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
955
+ mm_newline_position = getattr(
956
+ self.config, "mm_newline_position", "one_token"
957
+ )
958
+
959
+ if mm_patch_merge_type == "flat":
960
+ image_features = [x.flatten(0, 1) for x in image_features]
961
+
962
+ elif mm_patch_merge_type.startswith("spatial"):
963
+ new_image_features = []
964
+ for image_idx, image_feature in enumerate(image_features):
965
+ # FIXME: currently assumes the image is square and splits it into 2x2 patches
966
+ # num_patches = h * w, where h = w = sqrt(num_patches)
967
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
968
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
969
+ # rank0_print("At least we are reaching here")
970
+ # import pdb; pdb.set_trace()
971
+ if image_idx in video_idx_in_batch: # video operations
972
+ # rank0_print("Video")
973
+ if mm_newline_position == "grid":
974
+ # Grid-wise
975
+ image_feature = self.add_token_per_grid(image_feature)
976
+ new_image_features.append(image_feature)
977
+ elif mm_newline_position == "frame":
978
+ # Frame-wise
979
+ image_feature = self.add_token_per_frame(image_feature)
980
+ new_image_features.append(image_feature.flatten(0, 1))
981
+ elif mm_newline_position == "one_token":
982
+ # one-token
983
+ image_feature = image_feature.flatten(0, 1)
984
+ if "unpad" in mm_patch_merge_type:
985
+ image_feature = torch.cat(
986
+ (
987
+ image_feature,
988
+ self.model.image_newline[None].to(
989
+ image_feature.device
990
+ ),
991
+ ),
992
+ dim=0,
993
+ )
994
+ new_image_features.append(image_feature)
995
+ elif mm_newline_position == "no_token":
996
+ new_image_features.append(image_feature.flatten(0, 1))
997
+ else:
998
+ raise ValueError(
999
+ f"Unexpected mm_newline_position: {mm_newline_position}"
1000
+ )
1001
+ elif (
1002
+ image_feature.shape[0] > 1
1003
+ ): # multi patches and multi images operations
1004
+ # rank0_print("Single-images")
1005
+ base_image_feature = image_feature[0]
1006
+ image_feature = image_feature[1:]
1007
+ height = width = self.get_vision_tower().num_patches_per_side
1008
+ assert height * width == base_image_feature.shape[0]
1009
+
1010
+ if "anyres_max" in image_aspect_ratio:
1011
+ matched_anyres_max_num_patches = re.match(
1012
+ r"anyres_max_(\d+)", image_aspect_ratio
1013
+ )
1014
+ if matched_anyres_max_num_patches:
1015
+ max_num_patches = int(
1016
+ matched_anyres_max_num_patches.group(1)
1017
+ )
1018
+
1019
+ if (
1020
+ image_aspect_ratio == "anyres"
1021
+ or "anyres_max" in image_aspect_ratio
1022
+ ):
1023
+ if hasattr(self.get_vision_tower(), "image_size"):
1024
+ vision_tower_image_size = (
1025
+ self.get_vision_tower().image_size
1026
+ )
1027
+ else:
1028
+ raise ValueError(
1029
+ "vision_tower_image_size is not found in the vision tower."
1030
+ )
1031
+ try:
1032
+ (
1033
+ num_patch_width,
1034
+ num_patch_height,
1035
+ ) = get_anyres_image_grid_shape(
1036
+ image_sizes[image_idx],
1037
+ self.config.image_grid_pinpoints,
1038
+ vision_tower_image_size,
1039
+ )
1040
+ except Exception as e:
1041
+ rank0_print(f"Error: {e}")
1042
+ num_patch_width, num_patch_height = 2, 2
1043
+ image_feature = image_feature.view(
1044
+ num_patch_height, num_patch_width, height, width, -1
1045
+ )
1046
+ else:
1047
+ image_feature = image_feature.view(2, 2, height, width, -1)
1048
+
1049
+ if "maxpool2x2" in mm_patch_merge_type:
1050
+ image_feature = image_feature.permute(
1051
+ 4, 0, 2, 1, 3
1052
+ ).contiguous()
1053
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
1054
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
1055
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
1056
+ elif (
1057
+ "unpad" in mm_patch_merge_type
1058
+ and "anyres_max" in image_aspect_ratio
1059
+ and matched_anyres_max_num_patches
1060
+ ):
1061
+ unit = image_feature.shape[2]
1062
+ image_feature = image_feature.permute(
1063
+ 4, 0, 2, 1, 3
1064
+ ).contiguous()
1065
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
1066
+ image_feature = unpad_image(
1067
+ image_feature, image_sizes[image_idx]
1068
+ )
1069
+ c, h, w = image_feature.shape
1070
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
1071
+ if times > 1.1:
1072
+ image_feature = image_feature[None]
1073
+ image_feature = nn.functional.interpolate(
1074
+ image_feature,
1075
+ [int(h // times), int(w // times)],
1076
+ mode="bilinear",
1077
+ )[0]
1078
+ image_feature = torch.cat(
1079
+ (
1080
+ image_feature,
1081
+ self.model.image_newline[:, None, None]
1082
+ .expand(*image_feature.shape[:-1], 1)
1083
+ .to(image_feature.device),
1084
+ ),
1085
+ dim=-1,
1086
+ )
1087
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
1088
+ elif "unpad" in mm_patch_merge_type:
1089
+ image_feature = image_feature.permute(
1090
+ 4, 0, 2, 1, 3
1091
+ ).contiguous()
1092
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
1093
+ image_feature = unpad_image(
1094
+ image_feature, image_sizes[image_idx]
1095
+ )
1096
+ image_feature = torch.cat(
1097
+ (
1098
+ image_feature,
1099
+ self.model.image_newline[:, None, None]
1100
+ .expand(*image_feature.shape[:-1], 1)
1101
+ .to(image_feature.device),
1102
+ ),
1103
+ dim=-1,
1104
+ )
1105
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
1106
+ else:
1107
+ image_feature = image_feature.permute(
1108
+ 0, 2, 1, 3, 4
1109
+ ).contiguous()
1110
+ image_feature = image_feature.flatten(0, 3)
1111
+ if "nobase" in mm_patch_merge_type:
1112
+ pass
1113
+ else:
1114
+ image_feature = torch.cat(
1115
+ (base_image_feature, image_feature), dim=0
1116
+ )
1117
+ new_image_features.append(image_feature)
1118
+ else: # single image operations
1119
+ image_feature = image_feature[0]
1120
+ if "unpad" in mm_patch_merge_type:
1121
+ image_feature = torch.cat(
1122
+ (image_feature, self.model.image_newline[None]), dim=0
1123
+ )
1124
+
1125
+ new_image_features.append(image_feature)
1126
+ image_features = new_image_features
1127
+ else:
1128
+ raise ValueError(
1129
+ f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}"
1130
+ )
1131
+ else:
1132
+ image_features = self.encode_images(images)
1133
+
1134
+ # TODO: image start / end is not implemented here to support pretraining.
1135
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
1136
+ self.config, "mm_use_im_start_end", False
1137
+ ):
1138
+ raise NotImplementedError
1139
+ # Let's just add dummy tensors if they do not exist,
1140
+ # it is a headache to deal with None all the time.
1141
+ # But it is not ideal, and if you have a better idea,
1142
+ # please open an issue / submit a PR, thanks.
1143
+ _labels = labels
1144
+ _position_ids = position_ids
1145
+ _attention_mask = attention_mask
1146
+ if attention_mask is None:
1147
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1148
+ else:
1149
+ attention_mask = attention_mask.bool()
1150
+ if position_ids is None:
1151
+ position_ids = torch.arange(
1152
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
1153
+ )
1154
+ if labels is None:
1155
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1156
+
1157
+ # remove the padding using attention_mask -- FIXME
1158
+ _input_ids = input_ids
1159
+ input_ids = [
1160
+ cur_input_ids[cur_attention_mask]
1161
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
1162
+ ]
1163
+ labels = [
1164
+ cur_labels[cur_attention_mask]
1165
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
1166
+ ]
1167
+ new_input_embeds = []
1168
+ new_labels = []
1169
+ cur_speech_idx = 0
1170
+ cur_image_idx = 0
1171
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1172
+ num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
1173
+ num_image = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1174
+ if num_speech + num_image == 0:
1175
+ cur_speech_features = speech_features[cur_speech_idx]
1176
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1177
+ cur_input_embeds = torch.cat(
1178
+ [cur_input_embeds_1, cur_speech_features[0:0]], dim=0
1179
+ )
1180
+ new_input_embeds.append(cur_input_embeds)
1181
+ new_labels.append(labels[batch_idx])
1182
+ cur_speech_idx += 1
1183
+ cur_image_idx += 1
1184
+ continue
1185
+
1186
+ multimodal_token_indices = sorted(
1187
+ [-1]
1188
+ + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist()
1189
+ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
1190
+ + [cur_input_ids.shape[0]]
1191
+ )
1192
+ cur_input_ids_nospeech = []
1193
+ cur_labels = labels[batch_idx]
1194
+ cur_labels_nospeech = []
1195
+ for i in range(len(multimodal_token_indices) - 1):
1196
+ cur_input_ids_nospeech.append(
1197
+ cur_input_ids[
1198
+ multimodal_token_indices[i]
1199
+ + 1 : multimodal_token_indices[i + 1]
1200
+ ]
1201
+ )
1202
+ cur_labels_nospeech.append(
1203
+ cur_labels[
1204
+ multimodal_token_indices[i]
1205
+ + 1 : multimodal_token_indices[i + 1]
1206
+ ]
1207
+ )
1208
+
1209
+ split_sizes = [x.shape[0] for x in cur_labels_nospeech]
1210
+ cur_input_embeds = self.get_model().embed_tokens(
1211
+ torch.cat(cur_input_ids_nospeech)
1212
+ )
1213
+ cur_input_embeds_no_speech = torch.split(
1214
+ cur_input_embeds, split_sizes, dim=0
1215
+ )
1216
+ cur_new_input_embeds = []
1217
+ cur_new_labels = []
1218
+ for i in range(num_speech + num_image + 1):
1219
+ cur_new_input_embeds.append(cur_input_embeds_no_speech[i])
1220
+ cur_new_labels.append(cur_labels_nospeech[i])
1221
+ if cur_speech_idx < num_speech:
1222
+ try:
1223
+ cur_speech_features = speech_features[cur_speech_idx]
1224
+ except IndexError:
1225
+ cur_speech_features = speech_features[cur_speech_idx - 1]
1226
+ cur_speech_idx += 1
1227
+ cur_new_input_embeds.append(cur_speech_features)
1228
+ cur_new_labels.append(
1229
+ torch.full(
1230
+ (cur_speech_features.shape[0],),
1231
+ IGNORE_INDEX,
1232
+ device=cur_labels.device,
1233
+ dtype=cur_labels.dtype,
1234
+ )
1235
+ )
1236
+ if cur_image_idx < num_image:
1237
+ try:
1238
+ cur_image_features = image_features[cur_image_idx]
1239
+ except IndexError:
1240
+ cur_image_features = image_features[cur_image_idx - 1]
1241
+ cur_image_idx += 1
1242
+ cur_new_input_embeds.append(cur_image_features)
1243
+ cur_new_labels.append(
1244
+ torch.full(
1245
+ (cur_image_features.shape[0],),
1246
+ IGNORE_INDEX,
1247
+ device=cur_labels.device,
1248
+ dtype=cur_labels.dtype,
1249
+ )
1250
+ )
1251
+
1252
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1253
+
1254
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1255
+ cur_new_labels = torch.cat(cur_new_labels)
1256
+
1257
+ new_input_embeds.append(cur_new_input_embeds)
1258
+ new_labels.append(cur_new_labels)
1259
+
1260
+ # Truncate sequences to max length as speech features can make the sequence longer
1261
+ tokenizer_model_max_length = getattr(
1262
+ self.config, "tokenizer_model_max_length", None
1263
+ )
1264
+ if tokenizer_model_max_length is not None:
1265
+ new_input_embeds = [
1266
+ x[:tokenizer_model_max_length] for x in new_input_embeds
1267
+ ]
1268
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1269
+
1270
+ # Combine them
1271
+ max_len = max(x.shape[0] for x in new_input_embeds)
1272
+ batch_size = len(new_input_embeds)
1273
+
1274
+ new_input_embeds_padded = []
1275
+ new_labels_padded = torch.full(
1276
+ (batch_size, max_len),
1277
+ IGNORE_INDEX,
1278
+ dtype=new_labels[0].dtype,
1279
+ device=new_labels[0].device,
1280
+ )
1281
+ attention_mask = torch.zeros(
1282
+ (batch_size, max_len),
1283
+ dtype=attention_mask.dtype,
1284
+ device=attention_mask.device,
1285
+ )
1286
+ position_ids = torch.zeros(
1287
+ (batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
1288
+ )
1289
+
1290
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
1291
+ zip(new_input_embeds, new_labels)
1292
+ ):
1293
+ cur_len = cur_new_embed.shape[0]
1294
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
1295
+ new_input_embeds_padded.append(
1296
+ torch.cat(
1297
+ (
1298
+ torch.zeros(
1299
+ (max_len - cur_len, cur_new_embed.shape[1]),
1300
+ dtype=cur_new_embed.dtype,
1301
+ device=cur_new_embed.device,
1302
+ ),
1303
+ cur_new_embed,
1304
+ ),
1305
+ dim=0,
1306
+ )
1307
+ )
1308
+ if cur_len > 0:
1309
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1310
+ attention_mask[i, -cur_len:] = True
1311
+ position_ids[i, -cur_len:] = torch.arange(
1312
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
1313
+ )
1314
+ else:
1315
+ new_input_embeds_padded.append(
1316
+ torch.cat(
1317
+ (
1318
+ cur_new_embed,
1319
+ torch.zeros(
1320
+ (max_len - cur_len, cur_new_embed.shape[1]),
1321
+ dtype=cur_new_embed.dtype,
1322
+ device=cur_new_embed.device,
1323
+ ),
1324
+ ),
1325
+ dim=0,
1326
+ )
1327
+ )
1328
+ if cur_len > 0:
1329
+ new_labels_padded[i, :cur_len] = cur_new_labels
1330
+ attention_mask[i, :cur_len] = True
1331
+ position_ids[i, :cur_len] = torch.arange(
1332
+ 0, cur_len, dtype=position_ids.dtype, device=position_ids.device
1333
+ )
1334
+
1335
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1336
+ print(f"new_input_embeds: {new_input_embeds[0].shape}")
1337
+ if _labels is None:
1338
+ new_labels = None
1339
+ else:
1340
+ new_labels = new_labels_padded
1341
+
1342
+ if _attention_mask is None:
1343
+ attention_mask = None
1344
+ else:
1345
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1346
+
1347
+ if _position_ids is None:
1348
+ position_ids = None
1349
+
1350
+ return (
1351
+ None,
1352
+ position_ids,
1353
+ attention_mask,
1354
+ past_key_values,
1355
+ new_input_embeds,
1356
+ new_labels,
1357
+ )
egogpt/model/language_model/__pycache__/egogpt_llama.cpython-310.pyc ADDED
Binary file (3.83 kB). View file
 
egogpt/model/language_model/__pycache__/egogpt_qwen.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
egogpt/model/language_model/egogpt_llama.py ADDED
@@ -0,0 +1,159 @@
1
+ # Adapted from https://github.com/haotian-liu/LLaVA. We modified the code to support speech input. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from transformers import (
21
+ AutoConfig,
22
+ AutoModelForCausalLM,
23
+ LlamaConfig,
24
+ LlamaForCausalLM,
25
+ LlamaModel,
26
+ )
27
+ from transformers.generation.utils import GenerateOutput
28
+ from transformers.modeling_outputs import CausalLMOutputWithPast
29
+
30
+ from ..egogpt_arch import EgoGPTMetaForCausalLM, EgoGPTMetaModel
31
+
32
+
33
+ class EgoGPTConfig(LlamaConfig):
34
+ model_type = "egogpt_llama"
35
+
36
+
37
+ class EgoGPTLlamaModel(EgoGPTMetaModel, LlamaModel):
38
+ config_class = EgoGPTConfig
39
+
40
+ def __init__(self, config: LlamaConfig):
41
+ super(EgoGPTLlamaModel, self).__init__(config)
42
+
43
+
44
+ class EgoGPTLlamaForCausalLM(LlamaForCausalLM, EgoGPTMetaForCausalLM):
45
+ config_class = EgoGPTConfig
46
+
47
+ def __init__(self, config):
48
+ super(LlamaForCausalLM, self).__init__(config)
49
+ self.model = EgoGPTLlamaModel(config)
50
+ self.pretraining_tp = config.pretraining_tp
51
+ self.vocab_size = config.vocab_size
52
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
53
+
54
+ # Initialize weights and apply final processing
55
+ self.post_init()
56
+
57
+ def get_model(self):
58
+ return self.model
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: torch.LongTensor = None,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ position_ids: Optional[torch.LongTensor] = None,
65
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
66
+ inputs_embeds: Optional[torch.FloatTensor] = None,
67
+ labels: Optional[torch.LongTensor] = None,
68
+ use_cache: Optional[bool] = None,
69
+ output_attentions: Optional[bool] = None,
70
+ output_hidden_states: Optional[bool] = None,
71
+ speech: Optional[torch.FloatTensor] = None,
72
+ speech_lengths: Optional[torch.LongTensor] = None,
73
+ return_dict: Optional[bool] = None,
74
+ cache_position: Optional[torch.LongTensor] = None,
75
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
76
+ if inputs_embeds is None:
77
+ (
78
+ input_ids,
79
+ position_ids,
80
+ attention_mask,
81
+ past_key_values,
82
+ inputs_embeds,
83
+ labels,
84
+ ) = self.prepare_inputs_labels_for_speech_and_text(
85
+ input_ids,
86
+ position_ids,
87
+ attention_mask,
88
+ past_key_values,
89
+ labels,
90
+ speech,
91
+ speech_lengths,
92
+ )
93
+
94
+ return super().forward(
95
+ input_ids=input_ids,
96
+ attention_mask=attention_mask,
97
+ position_ids=position_ids,
98
+ past_key_values=past_key_values,
99
+ inputs_embeds=inputs_embeds,
100
+ labels=labels,
101
+ use_cache=use_cache,
102
+ output_attentions=output_attentions,
103
+ output_hidden_states=output_hidden_states,
104
+ return_dict=return_dict,
105
+ )
106
+
107
+ @torch.no_grad()
108
+ def generate(
109
+ self,
110
+ inputs: Optional[torch.Tensor] = None,
111
+ speech: Optional[torch.Tensor] = None,
112
+ speech_lengths: Optional[torch.Tensor] = None,
113
+ **kwargs,
114
+ ) -> Union[GenerateOutput, torch.LongTensor]:
115
+ position_ids = kwargs.pop("position_ids", None)
116
+ attention_mask = kwargs.pop("attention_mask", None)
117
+ if "inputs_embeds" in kwargs:
118
+ raise NotImplementedError("`inputs_embeds` is not supported")
119
+
120
+ if speech is not None:
121
+ (
122
+ inputs,
123
+ position_ids,
124
+ attention_mask,
125
+ _,
126
+ inputs_embeds,
127
+ _,
128
+ ) = self.prepare_inputs_labels_for_speech_and_text(
129
+ inputs, position_ids, attention_mask, None, None, speech, speech_lengths
130
+ )
131
+ else:
132
+ inputs_embeds = self.get_model().embed_tokens(inputs)
133
+
134
+ return super().generate(
135
+ position_ids=position_ids,
136
+ attention_mask=attention_mask,
137
+ inputs_embeds=inputs_embeds,
138
+ **kwargs,
139
+ )
140
+
141
+ def prepare_inputs_for_generation(
142
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
143
+ ):
144
+ speech = kwargs.pop("speech", None)
145
+ speech_lengths = kwargs.pop("speech_lengths", None)
146
+ inputs = super().prepare_inputs_for_generation(
147
+ input_ids,
148
+ past_key_values=past_key_values,
149
+ inputs_embeds=inputs_embeds,
150
+ **kwargs,
151
+ )
152
+ if speech is not None:
153
+ inputs["speech"] = speech
154
+ inputs["speech_lengths"] = speech_lengths
155
+ return inputs
156
+
157
+
158
+ AutoConfig.register("egogpt_llama", EgoGPTConfig)
159
+ AutoModelForCausalLM.register(EgoGPTConfig, EgoGPTLlamaForCausalLM)
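Because EgoGPTConfig and EgoGPTLlamaForCausalLM are registered with the Auto classes above, a checkpoint whose config declares model_type "egogpt_llama" can be loaded through the standard Auto API once this module has been imported. A minimal sketch, assuming a hypothetical local checkpoint path:

from transformers import AutoConfig, AutoModelForCausalLM

import egogpt.model.language_model.egogpt_llama  # noqa: F401  (importing runs the register() calls)

checkpoint = "/path/to/egogpt-llama-checkpoint"           # hypothetical path
config = AutoConfig.from_pretrained(checkpoint)           # resolves to EgoGPTConfig
model = AutoModelForCausalLM.from_pretrained(checkpoint)  # resolves to EgoGPTLlamaForCausalLM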
egogpt/model/language_model/egogpt_qwen.py ADDED
@@ -0,0 +1,164 @@
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import transformers
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoModelForCausalLM,
9
+ Qwen2Config,
10
+ Qwen2ForCausalLM,
11
+ Qwen2Model,
12
+ )
13
+ from transformers.generation.utils import GenerateOutput
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+
16
+ from ..egogpt_arch import EgoGPTMetaForCausalLM, EgoGPTMetaModel
17
+
18
+
19
+ class EgoGPTConfigQwen(Qwen2Config):
20
+ model_type = "egogpt_qwen"
21
+
22
+
23
+ class EgoGPTQwenModel(EgoGPTMetaModel, Qwen2Model):
24
+ config_class = EgoGPTConfigQwen
25
+
26
+ def __init__(self, config: Qwen2Config):
27
+ super(EgoGPTQwenModel, self).__init__(config)
28
+
29
+
30
+ class EgoGPTQwenForCausalLM(Qwen2ForCausalLM, EgoGPTMetaForCausalLM):
31
+ config_class = EgoGPTConfigQwen
32
+
33
+ def __init__(self, config):
34
+ super(Qwen2ForCausalLM, self).__init__(config)
35
+
36
+ config.rope_scaling = None
37
+ self.model = EgoGPTQwenModel(config)
38
+ self.vocab_size = config.vocab_size
39
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
40
+
41
+ # Initialize weights and apply final processing
42
+ self.post_init()
43
+
44
+ def get_model(self):
45
+ return self.model
46
+
47
+ def forward(
48
+ self,
49
+ input_ids: torch.LongTensor = None,
50
+ attention_mask: Optional[torch.Tensor] = None,
51
+ position_ids: Optional[torch.LongTensor] = None,
52
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
53
+ inputs_embeds: Optional[torch.FloatTensor] = None,
54
+ labels: Optional[torch.LongTensor] = None,
55
+ use_cache: Optional[bool] = None,
56
+ output_attentions: Optional[bool] = None,
57
+ output_hidden_states: Optional[bool] = None,
58
+ speech: Optional[torch.FloatTensor] = None,
59
+ speech_lengths: Optional[torch.LongTensor] = None,
60
+ images: Optional[torch.FloatTensor] = None,
61
+ image_sizes: Optional[List[List[int]]] = None,
62
+ modalities: Optional[List[str]] = ["image"],
63
+ return_dict: Optional[bool] = None,
64
+ cache_position: Optional[torch.LongTensor] = None,
65
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
66
+ if inputs_embeds is None:
67
+ (
68
+ input_ids,
69
+ position_ids,
70
+ attention_mask,
71
+ past_key_values,
72
+ inputs_embeds,
73
+ labels,
74
+ ) = self.prepare_inputs_labels_for_speech_and_text(
75
+ input_ids,
76
+ position_ids,
77
+ attention_mask,
78
+ past_key_values,
79
+ labels,
80
+ speech,
81
+ speech_lengths,
82
+ images,
83
+ image_sizes,
84
+ modalities,
85
+ )
86
+
87
+ return super().forward(
88
+ input_ids=input_ids,
89
+ attention_mask=attention_mask,
90
+ position_ids=position_ids,
91
+ past_key_values=past_key_values,
92
+ inputs_embeds=inputs_embeds,
93
+ labels=labels,
94
+ use_cache=use_cache,
95
+ output_attentions=output_attentions,
96
+ output_hidden_states=output_hidden_states,
97
+ return_dict=return_dict,
98
+ )
99
+
100
+ @torch.no_grad()
101
+ def generate(
102
+ self,
103
+ inputs: Optional[torch.Tensor] = None,
104
+ speech: Optional[torch.Tensor] = None,
105
+ speech_lengths: Optional[torch.Tensor] = None,
106
+ images: Optional[torch.FloatTensor] = None,
107
+ image_sizes: Optional[List[List[int]]] = None,
108
+ modalities: Optional[List[str]] = ["image"],
109
+ **kwargs,
110
+ ) -> Union[GenerateOutput, torch.LongTensor]:
111
+ position_ids = kwargs.pop("position_ids", None)
112
+ attention_mask = kwargs.pop("attention_mask", None)
113
+ if "inputs_embeds" in kwargs:
114
+ raise NotImplementedError("`inputs_embeds` is not supported")
115
+
116
+ if speech is not None:
117
+ (
118
+ inputs,
119
+ position_ids,
120
+ attention_mask,
121
+ _,
122
+ inputs_embeds,
123
+ _,
124
+ ) = self.prepare_inputs_labels_for_speech_and_text(
125
+ inputs,
126
+ position_ids,
127
+ attention_mask,
128
+ None,
129
+ None,
130
+ speech,
131
+ speech_lengths,
132
+ images,
133
+ image_sizes,
134
+ modalities,
135
+ )
136
+ else:
137
+ inputs_embeds = self.get_model().embed_tokens(inputs)
138
+
139
+ return super().generate(
140
+ position_ids=position_ids,
141
+ attention_mask=attention_mask,
142
+ inputs_embeds=inputs_embeds,
143
+ **kwargs,
144
+ )
145
+
146
+ def prepare_inputs_for_generation(
147
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
148
+ ):
149
+ speech = kwargs.pop("speech", None)
150
+ speech_lengths = kwargs.pop("speech_lengths", None)
151
+ inputs = super().prepare_inputs_for_generation(
152
+ input_ids,
153
+ past_key_values=past_key_values,
154
+ inputs_embeds=inputs_embeds,
155
+ **kwargs,
156
+ )
157
+ if speech is not None:
158
+ inputs["speech"] = speech
159
+ inputs["speech_lengths"] = speech_lengths
160
+ return inputs
161
+
162
+
163
+ AutoConfig.register("egogpt_qwen", EgoGPTConfigQwen)
164
+ AutoModelForCausalLM.register(EgoGPTConfigQwen, EgoGPTQwenForCausalLM)
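Compared with the Llama variant, the Qwen class also threads images, image_sizes, and modalities through forward() and generate() alongside the speech tensors. An illustrative sketch of assembling those arguments; all shapes and the mel layout are assumptions, not repo defaults:

import torch

frames = torch.randn(8, 3, 384, 384)   # one video clip -> (num_frames, C, H, W)
speech = torch.randn(1, 128, 3000)     # Whisper-style log-mel features (assumed layout)
speech_lengths = torch.tensor([3000])

gen_kwargs = dict(
    images=[frames],
    image_sizes=[tuple(frames.shape[-2:])],
    modalities=["video"],
    speech=speech,
    speech_lengths=speech_lengths,
)
# With a loaded EgoGPTQwenForCausalLM and a tokenized prompt containing the placeholders:
# model.generate(input_ids, **gen_kwargs, max_new_tokens=128)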
egogpt/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (914 Bytes). View file
 
egogpt/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc ADDED
Binary file (6.59 kB). View file
 
egogpt/model/multimodal_encoder/__pycache__/siglip_encoder.cpython-310.pyc ADDED
Binary file (22.2 kB). View file
 
egogpt/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+
3
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
4
+ from .siglip_encoder import SigLipVisionTower
5
+
6
+ # from .eva_clip.eva_clip_encoder import EvaClipVisionTower
7
+ # from .dev_eva_clip.eva_vit import EvaViTWrapper
8
+
9
+
10
+ def build_vision_tower(vision_tower_cfg, **kwargs):
11
+ vision_tower = getattr(
12
+ vision_tower_cfg,
13
+ "mm_vision_tower",
14
+ getattr(vision_tower_cfg, "vision_tower", None),
15
+ )
16
+ is_absolute_path_exists = os.path.exists(vision_tower)
17
+ use_s2 = getattr(vision_tower_cfg, "s2", False)
18
+ # if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
19
+ if (
20
+ vision_tower.startswith("openai")
21
+ or vision_tower.startswith("laion")
22
+ or "ShareGPT4V" in vision_tower
23
+ ):
24
+ if use_s2:
25
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
26
+ else:
27
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
28
+ elif (
29
+ "siglip" in vision_tower.lower()
30
+ or "open_clip_pytorch_model.bin" in vision_tower
31
+ ):
32
+ return SigLipVisionTower(
33
+ vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs
34
+ )
35
+
36
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
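A minimal sketch of invoking build_vision_tower with a SigLIP checkpoint name; the config object is hypothetical and only sets attributes the towers are expected to read:

from types import SimpleNamespace

from egogpt.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="google/siglip-so400m-patch14-384",  # any name containing "siglip"
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg)  # dispatches to SigLipVisionTower for this name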
egogpt/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
4
+
5
+ from egogpt.utils import rank0_print
6
+
7
+ try:
8
+ from s2wrapper import forward as multiscale_forward
9
+ except ImportError:
10
+ pass
11
+
12
+
13
+ class CLIPVisionTower(nn.Module):
14
+ def __init__(self, vision_tower, args, delay_load=False):
15
+ super().__init__()
16
+
17
+ self.is_loaded = False
18
+
19
+ self.vision_tower_name = vision_tower
20
+ self.select_layer = args.mm_vision_select_layer
21
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
22
+
23
+ if not delay_load:
24
+ rank0_print(f"Loading vision tower: {vision_tower}")
25
+ self.load_model()
26
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
27
+ # TODO: better detector is needed.
28
+ rank0_print(
29
+ f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True."
30
+ )
31
+ self.load_model()
32
+ elif (
33
+ hasattr(args, "mm_tunable_parts")
34
+ and "mm_vision_tower" in args.mm_tunable_parts
35
+ ):
36
+ rank0_print(
37
+ f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`."
38
+ )
39
+ self.load_model()
40
+ else:
41
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
42
+
43
+ def load_model(self, device_map=None):
44
+ if self.is_loaded:
45
+ rank0_print(
46
+ "{} is already loaded, `load_model` called again, skipping.".format(
47
+ self.vision_tower_name
48
+ )
49
+ )
50
+ return
51
+
52
+ self.image_processor = CLIPImageProcessor.from_pretrained(
53
+ self.vision_tower_name
54
+ )
55
+ self.vision_tower = CLIPVisionModel.from_pretrained(
56
+ self.vision_tower_name, device_map=device_map
57
+ )
58
+ self.vision_tower.requires_grad_(False)
59
+
60
+ self.is_loaded = True
61
+
62
+ def feature_select(self, image_forward_outs):
63
+ select_feature_type = self.select_feature
64
+
65
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
66
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
67
+ image_features = torch.cat(
68
+ [
69
+ image_forward_outs.hidden_states[i]
70
+ for i in range(
71
+ select_every_k_layer + self.select_layer,
72
+ len(image_forward_outs.hidden_states),
73
+ select_every_k_layer,
74
+ )
75
+ ],
76
+ dim=-1,
77
+ )
78
+ select_feature_type = select_feature_type.replace("slicefour_", "")
79
+ elif self.select_feature in [
80
+ "slice_m25811_f6_patch",
81
+ "slice_m25811_f6_cls_patch",
82
+ ]:
83
+ select_layers = [-2, -5, -8, -11, 6]
84
+ image_features = torch.cat(
85
+ [image_forward_outs.hidden_states[i] for i in select_layers], dim=-1
86
+ )
87
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
88
+ else:
89
+ image_features = image_forward_outs.hidden_states[self.select_layer]
90
+
91
+ if select_feature_type == "patch":
92
+ image_features = image_features[:, 1:]
93
+ elif select_feature_type == "cls_patch":
94
+ image_features = image_features
95
+ else:
96
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
97
+ return image_features
98
+
99
+ def forward(self, images):
100
+ if type(images) is list:
101
+ image_features = []
102
+ for image in images:
103
+ image_forward_out = self.vision_tower(
104
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
105
+ output_hidden_states=True,
106
+ )
107
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
108
+ image_features.append(image_feature)
109
+ else:
110
+ image_forward_outs = self.vision_tower(
111
+ images.to(device=self.device, dtype=self.dtype),
112
+ output_hidden_states=True,
113
+ )
114
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
115
+
116
+ return image_features
117
+
118
+ @property
119
+ def dummy_feature(self):
120
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
121
+
122
+ @property
123
+ def dtype(self):
124
+ return self.vision_tower.dtype
125
+
126
+ @property
127
+ def device(self):
128
+ return self.vision_tower.device
129
+
130
+ @property
131
+ def config(self):
132
+ if self.is_loaded:
133
+ return self.vision_tower.config
134
+ else:
135
+ return self.cfg_only
136
+
137
+ @property
138
+ def hidden_size(self):
139
+ _hidden_size = self.config.hidden_size
140
+ if "slicefour" in self.select_feature:
141
+ _hidden_size *= 4
142
+ if "slice_m25811_f6" in self.select_feature:
143
+ _hidden_size *= 5
144
+ return _hidden_size
145
+
146
+ @property
147
+ def num_patches_per_side(self):
148
+ return self.config.image_size // self.config.patch_size
149
+
150
+ @property
151
+ def num_patches(self):
152
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
153
+ if "cls_patch" in self.select_feature:
154
+ _num_patches += 1
155
+ return _num_patches
156
+
157
+ @property
158
+ def image_size(self):
159
+ return self.config.image_size
160
+
161
+
162
+ class CLIPVisionTowerS2(CLIPVisionTower):
163
+ def __init__(self, vision_tower, args, delay_load=False):
164
+ self.s2_scales = getattr(args, "s2_scales", "336,672,1008")
165
+ self.s2_scales = list(map(int, self.s2_scales.split(",")))
166
+ self.s2_scales.sort()
167
+ self.s2_split_size = self.s2_scales[0]
168
+ self.s2_image_size = self.s2_scales[-1]
169
+
170
+ super().__init__(vision_tower, args, delay_load)
171
+
172
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
173
+ if not delay_load or getattr(args, "unfreeze_mm_vision_tower", False):
174
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
175
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size[
176
+ "width"
177
+ ] = self.s2_image_size
178
+
179
+ def load_model(self, device_map=None):
180
+ if self.is_loaded:
181
+ rank0_print(
182
+ "{} is already loaded, `load_model` called again, skipping.".format(
183
+ self.vision_tower_name
184
+ )
185
+ )
186
+ return
187
+
188
+ self.image_processor = CLIPImageProcessor.from_pretrained(
189
+ self.vision_tower_name
190
+ )
191
+ self.vision_tower = CLIPVisionModel.from_pretrained(
192
+ self.vision_tower_name, device_map=device_map
193
+ )
194
+ self.vision_tower.requires_grad_(False)
195
+
196
+ self.image_processor.size["shortest_edge"] = self.s2_image_size
197
+ self.image_processor.crop_size["height"] = self.image_processor.crop_size[
198
+ "width"
199
+ ] = self.s2_image_size
200
+
201
+ self.is_loaded = True
202
+
203
+ def forward_feature(self, images):
204
+ image_forward_outs = self.vision_tower(
205
+ images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
206
+ )
207
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
208
+ return image_features
209
+
210
+ def forward(self, images):
211
+ if type(images) is list:
212
+ image_features = []
213
+ for image in images:
214
+ image_feature = multiscale_forward(
215
+ self.forward_feature,
216
+ image.unsqueeze(0),
217
+ img_sizes=self.s2_scales,
218
+ max_split_size=self.s2_split_size,
219
+ split_forward=True,
220
+ )
221
+ image_features.append(image_feature)
222
+ else:
223
+ image_features = multiscale_forward(
224
+ self.forward_feature,
225
+ images,
226
+ img_sizes=self.s2_scales,
227
+ max_split_size=self.s2_split_size,
228
+ split_forward=True,
229
+ )
230
+
231
+ return image_features
232
+
233
+ @property
234
+ def hidden_size(self):
235
+ return self.config.hidden_size * len(self.s2_scales)
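Because CLIPVisionTowerS2 concatenates features from every scale, the projector input width grows with the number of scales. A quick check of the arithmetic the hidden_size property above implements, assuming CLIP ViT-L/14-336 (hidden size 1024) and the default scales:

s2_scales = [336, 672, 1008]              # parsed from the default "336,672,1008"
clip_hidden_size = 1024                   # ViT-L/14 hidden size (assumed tower choice)
print(clip_hidden_size * len(s2_scales))  # 3072, the width the multimodal projector must accept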
egogpt/model/multimodal_encoder/siglip_encoder.py ADDED
@@ -0,0 +1,742 @@
1
+ """
2
+ # Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
3
+ """
4
+
5
+ import os
6
+ from dataclasses import dataclass
7
+ from functools import partial, reduce
8
+ from typing import Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.utils.checkpoint
12
+ from PIL import Image
13
+ from torch import nn
14
+ from transformers import PretrainedConfig
15
+ from transformers.activations import ACT2FN
16
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
17
+ from transformers.image_transforms import (
18
+ convert_to_rgb,
19
+ normalize,
20
+ rescale,
21
+ resize,
22
+ to_channel_dimension_format,
23
+ )
24
+ from transformers.image_utils import (
25
+ ChannelDimension,
26
+ PILImageResampling,
27
+ to_numpy_array,
28
+ )
29
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import ModelOutput
32
+
33
+ from egogpt.utils import rank0_print
34
+
35
+
36
+ class SigLipImageProcessor:
37
+ def __init__(
38
+ self,
39
+ image_mean=(0.5, 0.5, 0.5),
40
+ image_std=(0.5, 0.5, 0.5),
41
+ size=(384, 384),
42
+ crop_size: Dict[str, int] = None,
43
+ resample=PILImageResampling.BICUBIC,
44
+ rescale_factor=1 / 255,
45
+ data_format=ChannelDimension.FIRST,
46
+ ):
47
+ crop_size = (
48
+ crop_size if crop_size is not None else {"height": 384, "width": 384}
49
+ )
50
+ crop_size = get_size_dict(
51
+ crop_size, default_to_square=True, param_name="crop_size"
52
+ )
53
+
54
+ self.image_mean = image_mean
55
+ self.image_std = image_std
56
+ self.size = size
57
+ self.resample = resample
58
+ self.rescale_factor = rescale_factor
59
+ self.data_format = data_format
60
+ self.crop_size = crop_size
61
+
62
+ def preprocess(self, images, return_tensors):
63
+ if isinstance(images, Image.Image):
64
+ images = [images]
65
+ else:
66
+ # handle video data: a sequence of frames rather than a single PIL image
67
+ images = [to_numpy_array(image) for image in images]
68
+ assert isinstance(images, list)
69
+
70
+ transforms = [
71
+ convert_to_rgb,
72
+ to_numpy_array,
73
+ partial(
74
+ resize,
75
+ size=self.size,
76
+ resample=self.resample,
77
+ data_format=self.data_format,
78
+ ),
79
+ partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
80
+ partial(
81
+ normalize,
82
+ mean=self.image_mean,
83
+ std=self.image_std,
84
+ data_format=self.data_format,
85
+ ),
86
+ partial(
87
+ to_channel_dimension_format,
88
+ channel_dim=self.data_format,
89
+ input_channel_dim=self.data_format,
90
+ ),
91
+ ]
92
+
93
+ images = reduce(lambda x, f: [*map(f, x)], transforms, images)
94
+ data = {"pixel_values": images}
95
+
96
+ return BatchFeature(data=data, tensor_type=return_tensors)
97
+
98
+
99
+ class SigLipVisionConfig(PretrainedConfig):
100
+ model_type = "siglip_vision_model"
101
+
102
+ def __init__(
103
+ self,
104
+ hidden_size=1152,
105
+ image_mean=(0.5, 0.5, 0.5),
106
+ intermediate_size=4304,
107
+ num_hidden_layers=27,
108
+ num_attention_heads=16,
109
+ num_channels=3,
110
+ image_size=384,
111
+ patch_size=14,
112
+ hidden_act="gelu_pytorch_tanh",
113
+ layer_norm_eps=1e-6,
114
+ attention_dropout=0.0,
115
+ **kwargs,
116
+ ):
117
+ super().__init__(**kwargs)
118
+
119
+ self.hidden_size = hidden_size
120
+ self.intermediate_size = intermediate_size
121
+ self.num_hidden_layers = num_hidden_layers
122
+ self.num_attention_heads = num_attention_heads
123
+ self.num_channels = num_channels
124
+ self.patch_size = patch_size
125
+ self.image_size = image_size
126
+ self.attention_dropout = attention_dropout
127
+ self.layer_norm_eps = layer_norm_eps
128
+ self.hidden_act = hidden_act
129
+ self.image_mean = image_mean
130
+
131
+ @classmethod
132
+ def from_pretrained(
133
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
134
+ ) -> "PretrainedConfig":
135
+ cls._set_token_in_kwargs(kwargs)
136
+
137
+ config_dict, kwargs = cls.get_config_dict(
138
+ pretrained_model_name_or_path, **kwargs
139
+ )
140
+
141
+ # get the vision config dict if we are loading from SigLipConfig
142
+ if config_dict.get("model_type") == "siglip":
143
+ config_dict = config_dict["vision_config"]
144
+
145
+ if (
146
+ "model_type" in config_dict
147
+ and hasattr(cls, "model_type")
148
+ and config_dict["model_type"] != cls.model_type
149
+ ):
150
+ print(
151
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
152
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
153
+ )
154
+
155
+ return cls.from_dict(config_dict, **kwargs)
156
+
157
+
158
+ @dataclass
159
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
160
+ class SigLipVisionModelOutput(ModelOutput):
161
+ """
162
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
163
+
164
+ Args:
165
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
166
+ The image embeddings obtained by applying the projection layer to the pooler_output.
167
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
168
+ Sequence of hidden-states at the output of the last layer of the model.
169
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
170
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
171
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
172
+
173
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
174
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
175
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
176
+ sequence_length)`.
177
+
178
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
179
+ heads.
180
+ """
181
+
182
+ image_embeds: Optional[torch.FloatTensor] = None
183
+ last_hidden_state: torch.FloatTensor = None
184
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
185
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
186
+
187
+
188
+ class SigLipVisionEmbeddings(nn.Module):
189
+ def __init__(self, config: SigLipVisionConfig):
190
+ super().__init__()
191
+ self.config = config
192
+ self.embed_dim = config.hidden_size
193
+ self.image_size = config.image_size
194
+ self.patch_size = config.patch_size
195
+
196
+ self.patch_embedding = nn.Conv2d(
197
+ in_channels=config.num_channels,
198
+ out_channels=self.embed_dim,
199
+ kernel_size=self.patch_size,
200
+ stride=self.patch_size,
201
+ padding="valid",
202
+ )
203
+
204
+ self.num_patches = (self.image_size // self.patch_size) ** 2
205
+ self.num_positions = self.num_patches
206
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
207
+ self.register_buffer(
208
+ "position_ids",
209
+ torch.arange(self.num_positions).expand((1, -1)),
210
+ persistent=False,
211
+ )
212
+
213
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
214
+ patch_embeds = self.patch_embedding(
215
+ pixel_values
216
+ ) # shape = [*, width, grid, grid]
217
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
218
+
219
+ embeddings = embeddings + self.position_embedding(self.position_ids)
220
+ return embeddings
221
+
222
+
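# --- Illustrative sketch (not part of this commit): patch-token arithmetic implied by the
# embedding layer above with the default SigLipVisionConfig (image_size=384, patch_size=14).
cfg = SigLipVisionConfig()
num_patches = (cfg.image_size // cfg.patch_size) ** 2   # (384 // 14) ** 2 = 27 ** 2 = 729
assert num_patches == 729   # the per-image token count asserted later in SigLipVisionTower.forward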
223
+ class SigLipAttention(nn.Module):
224
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
225
+
226
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
227
+ def __init__(self, config):
228
+ super().__init__()
229
+ self.config = config
230
+ self.embed_dim = config.hidden_size
231
+ self.num_heads = config.num_attention_heads
232
+ self.head_dim = self.embed_dim // self.num_heads
233
+ if self.head_dim * self.num_heads != self.embed_dim:
234
+ raise ValueError(
235
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
236
+ f" {self.num_heads})."
237
+ )
238
+ self.scale = self.head_dim**-0.5
239
+ self.dropout = config.attention_dropout
240
+
241
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
242
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
243
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
244
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
245
+
246
+ def forward(
247
+ self,
248
+ hidden_states: torch.Tensor,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ output_attentions: Optional[bool] = False,
251
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
252
+ """Input shape: Batch x Time x Channel"""
253
+
254
+ batch_size, q_len, _ = hidden_states.size()
255
+
256
+ query_states = self.q_proj(hidden_states)
257
+ key_states = self.k_proj(hidden_states)
258
+ value_states = self.v_proj(hidden_states)
259
+
260
+ query_states = query_states.view(
261
+ batch_size, q_len, self.num_heads, self.head_dim
262
+ ).transpose(1, 2)
263
+ key_states = key_states.view(
264
+ batch_size, q_len, self.num_heads, self.head_dim
265
+ ).transpose(1, 2)
266
+ value_states = value_states.view(
267
+ batch_size, q_len, self.num_heads, self.head_dim
268
+ ).transpose(1, 2)
269
+
270
+ k_v_seq_len = key_states.shape[-2]
271
+ attn_weights = (
272
+ torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
273
+ )
274
+
275
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
276
+ raise ValueError(
277
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
278
+ f" {attn_weights.size()}"
279
+ )
280
+
281
+ if attention_mask is not None:
282
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
283
+ raise ValueError(
284
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
285
+ )
286
+ attn_weights = attn_weights + attention_mask
287
+
288
+ # upcast attention to fp32
289
+ attn_weights = nn.functional.softmax(
290
+ attn_weights, dim=-1, dtype=torch.float32
291
+ ).to(query_states.dtype)
292
+ attn_weights = nn.functional.dropout(
293
+ attn_weights, p=self.dropout, training=self.training
294
+ )
295
+ attn_output = torch.matmul(attn_weights, value_states)
296
+
297
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
298
+ raise ValueError(
299
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
300
+ f" {attn_output.size()}"
301
+ )
302
+
303
+ attn_output = attn_output.transpose(1, 2).contiguous()
304
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
305
+
306
+ attn_output = self.out_proj(attn_output)
307
+
308
+ return attn_output, attn_weights
309
+
310
+
311
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
312
+ class SigLipMLP(nn.Module):
313
+ def __init__(self, config):
314
+ super().__init__()
315
+ self.config = config
316
+ self.activation_fn = ACT2FN[config.hidden_act]
317
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
318
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
319
+
320
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
321
+ hidden_states = self.fc1(hidden_states)
322
+ hidden_states = self.activation_fn(hidden_states)
323
+ hidden_states = self.fc2(hidden_states)
324
+ return hidden_states
325
+
326
+
327
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
328
+ class SigLipEncoderLayer(nn.Module):
329
+ def __init__(self, config: SigLipVisionConfig):
330
+ super().__init__()
331
+ self.embed_dim = config.hidden_size
332
+ self.self_attn = SigLipAttention(config)
333
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
334
+ self.mlp = SigLipMLP(config)
335
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
336
+
337
+ # Ignore copy
338
+ def forward(
339
+ self,
340
+ hidden_states: torch.Tensor,
341
+ attention_mask: torch.Tensor,
342
+ output_attentions: Optional[bool] = False,
343
+ ) -> Tuple[torch.FloatTensor]:
344
+ """
345
+ Args:
346
+ hidden_states (`torch.FloatTensor`):
347
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
348
+ attention_mask (`torch.FloatTensor`):
349
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
350
+ output_attentions (`bool`, *optional*, defaults to `False`):
351
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
352
+ returned tensors for more detail.
353
+ """
354
+ residual = hidden_states
355
+
356
+ hidden_states = self.layer_norm1(hidden_states)
357
+ hidden_states, attn_weights = self.self_attn(
358
+ hidden_states=hidden_states,
359
+ attention_mask=attention_mask,
360
+ output_attentions=output_attentions,
361
+ )
362
+ hidden_states = residual + hidden_states
363
+
364
+ residual = hidden_states
365
+ hidden_states = self.layer_norm2(hidden_states)
366
+ hidden_states = self.mlp(hidden_states)
367
+ hidden_states = residual + hidden_states
368
+
369
+ outputs = (hidden_states,)
370
+
371
+ if output_attentions:
372
+ outputs += (attn_weights,)
373
+
374
+ return outputs
375
+
376
+
377
+ class SigLipPreTrainedModel(PreTrainedModel):
378
+ """
379
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
380
+ models.
381
+ """
382
+
383
+ config_class = SigLipVisionConfig
384
+ base_model_prefix = "siglip"
385
+ supports_gradient_checkpointing = True
386
+
387
+ def _init_weights(self, module):
388
+ """Initialize the weights"""
389
+ pass
390
+
391
+
392
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
393
+ class SigLipEncoder(nn.Module):
394
+ """
395
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
396
+ [`SigLipEncoderLayer`].
397
+
398
+ Args:
399
+ config: SigLipVisionConfig
400
+ """
401
+
402
+ def __init__(self, config: SigLipVisionConfig):
403
+ super().__init__()
404
+ self.config = config
405
+ self.layers = nn.ModuleList(
406
+ [SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
407
+ )
408
+ self.gradient_checkpointing = False
409
+
410
+ # Ignore copy
411
+ def forward(
412
+ self,
413
+ inputs_embeds,
414
+ attention_mask: Optional[torch.Tensor] = None,
415
+ output_attentions: Optional[bool] = None,
416
+ output_hidden_states: Optional[bool] = None,
417
+ return_dict: Optional[bool] = None,
418
+ ) -> Union[Tuple, BaseModelOutput]:
419
+ r"""
420
+ Args:
421
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
422
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
423
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
424
+ than the model's internal embedding lookup matrix.
425
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
426
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
427
+
428
+ - 1 for tokens that are **not masked**,
429
+ - 0 for tokens that are **masked**.
430
+
431
+ [What are attention masks?](../glossary#attention-mask)
432
+ output_attentions (`bool`, *optional*):
433
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
434
+ returned tensors for more detail.
435
+ output_hidden_states (`bool`, *optional*):
436
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
437
+ for more detail.
438
+ return_dict (`bool`, *optional*):
439
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
440
+ """
441
+ output_attentions = (
442
+ output_attentions
443
+ if output_attentions is not None
444
+ else self.config.output_attentions
445
+ )
446
+ output_hidden_states = (
447
+ output_hidden_states
448
+ if output_hidden_states is not None
449
+ else self.config.output_hidden_states
450
+ )
451
+ return_dict = (
452
+ return_dict if return_dict is not None else self.config.use_return_dict
453
+ )
454
+
455
+ encoder_states = () if output_hidden_states else None
456
+ all_attentions = () if output_attentions else None
457
+
458
+ hidden_states = inputs_embeds
459
+ for encoder_layer in self.layers:
460
+ if output_hidden_states:
461
+ encoder_states = encoder_states + (hidden_states,)
462
+ if self.gradient_checkpointing and self.training:
463
+ layer_outputs = self._gradient_checkpointing_func(
464
+ encoder_layer.__call__,
465
+ hidden_states,
466
+ attention_mask,
467
+ output_attentions,
468
+ )
469
+ else:
470
+ layer_outputs = encoder_layer(
471
+ hidden_states,
472
+ attention_mask,
473
+ output_attentions=output_attentions,
474
+ )
475
+
476
+ hidden_states = layer_outputs[0]
477
+
478
+ if output_attentions:
479
+ all_attentions = all_attentions + (layer_outputs[1],)
480
+
481
+ if output_hidden_states:
482
+ encoder_states = encoder_states + (hidden_states,)
483
+
484
+ if not return_dict:
485
+ return tuple(
486
+ v
487
+ for v in [hidden_states, encoder_states, all_attentions]
488
+ if v is not None
489
+ )
490
+ return BaseModelOutput(
491
+ last_hidden_state=hidden_states,
492
+ hidden_states=encoder_states,
493
+ attentions=all_attentions,
494
+ )
495
+
496
+
497
+ class SigLipVisionTransformer(nn.Module):
498
+ def __init__(self, config: SigLipVisionConfig):
499
+ super().__init__()
500
+ self.config = config
501
+ embed_dim = config.hidden_size
502
+
503
+ self.embeddings = SigLipVisionEmbeddings(config)
504
+ self.encoder = SigLipEncoder(config)
505
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
506
+ self.head = SigLipMultiheadAttentionPoolingHead(config)
507
+
508
+ def forward(
509
+ self,
510
+ pixel_values,
511
+ output_attentions: Optional[bool] = None,
512
+ output_hidden_states: Optional[bool] = None,
513
+ return_dict: Optional[bool] = None,
514
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
515
+ r"""
516
+ Returns:
517
+
518
+ """
519
+ output_attentions = (
520
+ output_attentions
521
+ if output_attentions is not None
522
+ else self.config.output_attentions
523
+ )
524
+ output_hidden_states = (
525
+ output_hidden_states
526
+ if output_hidden_states is not None
527
+ else self.config.output_hidden_states
528
+ )
529
+ return_dict = (
530
+ return_dict if return_dict is not None else self.config.use_return_dict
531
+ )
532
+
533
+ hidden_states = self.embeddings(pixel_values)
534
+
535
+ encoder_outputs = self.encoder(
536
+ inputs_embeds=hidden_states,
537
+ output_attentions=output_attentions,
538
+ output_hidden_states=output_hidden_states,
539
+ return_dict=return_dict,
540
+ )
541
+
542
+ last_hidden_state = encoder_outputs[0]
543
+ last_hidden_state = self.post_layernorm(last_hidden_state)
544
+
545
+ pooled_output = self.head(last_hidden_state)
546
+
547
+ if not return_dict:
548
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
549
+
550
+ return BaseModelOutputWithPooling(
551
+ last_hidden_state=last_hidden_state,
552
+ pooler_output=pooled_output,
553
+ hidden_states=encoder_outputs.hidden_states,
554
+ attentions=encoder_outputs.attentions,
555
+ )
556
+
557
+
558
+ class SigLipMultiheadAttentionPoolingHead(nn.Module):
559
+ """Multihead Attention Pooling."""
560
+
561
+ def __init__(self, config: SigLipVisionConfig):
562
+ super().__init__()
563
+
564
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
565
+ self.attention = torch.nn.MultiheadAttention(
566
+ config.hidden_size, config.num_attention_heads, batch_first=True
567
+ )
568
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
569
+ self.mlp = SigLipMLP(config)
570
+
571
+ def forward(self, hidden_state):
572
+ batch_size = hidden_state.shape[0]
573
+ probe = self.probe.repeat(batch_size, 1, 1)
574
+
575
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
576
+
577
+ residual = hidden_state
578
+ hidden_state = self.layernorm(hidden_state)
579
+ hidden_state = residual + self.mlp(hidden_state)
580
+
581
+ return hidden_state[:, 0]
582
+
583
+
584
+ class SigLipVisionModel(SigLipPreTrainedModel):
585
+ config_class = SigLipVisionConfig
586
+ main_input_name = "pixel_values"
587
+ _no_split_modules = ["SigLipEncoderLayer"]
588
+
589
+ def __init__(self, config: SigLipVisionConfig):
590
+ super().__init__(config)
591
+
592
+ self.vision_model = SigLipVisionTransformer(config)
593
+
594
+ # Initialize weights and apply final processing
595
+ self.post_init()
596
+
597
+ def get_input_embeddings(self) -> nn.Module:
598
+ return self.vision_model.embeddings.patch_embedding
599
+
600
+ def forward(
601
+ self,
602
+ pixel_values,
603
+ output_attentions: Optional[bool] = None,
604
+ output_hidden_states: Optional[bool] = None,
605
+ return_dict: Optional[bool] = None,
606
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
607
+ r"""
608
+ Returns:
609
+
610
+ Examples:
611
+
612
+ ```python
613
+ >>> from PIL import Image
614
+ >>> import requests
615
+ >>> from transformers import AutoProcessor, SigLipVisionModel
616
+
617
+ >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
618
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
619
+
620
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
621
+ >>> image = Image.open(requests.get(url, stream=True).raw)
622
+
623
+ >>> inputs = processor(images=image, return_tensors="pt")
624
+
625
+ >>> outputs = model(**inputs)
626
+ >>> last_hidden_state = outputs.last_hidden_state
627
+ >>> pooled_output = outputs.pooler_output # pooled features
628
+ ```"""
629
+ return_dict = (
630
+ return_dict if return_dict is not None else self.config.use_return_dict
631
+ )
632
+
633
+ return self.vision_model(
634
+ pixel_values=pixel_values,
635
+ output_attentions=output_attentions,
636
+ output_hidden_states=output_hidden_states,
637
+ return_dict=return_dict,
638
+ )
639
+
640
+
641
+ class SigLipVisionTower(nn.Module):
642
+ def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
643
+ super().__init__()
644
+
645
+ self.is_loaded = False
646
+
647
+ self.config = SigLipVisionConfig()
648
+
649
+ self.vision_tower_name = vision_tower
650
+
651
+ self.image_processor = SigLipImageProcessor()
652
+
653
+ if not delay_load:
654
+ rank0_print(f"Loading vision tower: {vision_tower}")
655
+ self.load_model()
656
+ elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
657
+ # TODO: better detector is needed.
658
+ rank0_print(
659
+ f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True."
660
+ )
661
+ self.load_model()
662
+ elif (
663
+ hasattr(vision_tower_cfg, "mm_tunable_parts")
664
+ and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts
665
+ ):
666
+ rank0_print(
667
+ f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`."
668
+ )
669
+ self.load_model()
670
+ else:
671
+ self.cfg_only = self.config
672
+
673
+ def load_model(self, device_map=None):
674
+ if self.is_loaded:
675
+ rank0_print(
676
+ "{} is already loaded, `load_model` called again, skipping.".format(
677
+ self.vision_tower_name
678
+ )
679
+ )
680
+ return
681
+
682
+ self.vision_tower = SigLipVisionModel.from_pretrained(
683
+ self.vision_tower_name, device_map=device_map
684
+ )
685
+
686
+ del self.vision_tower.vision_model.encoder.layers[-1:]
687
+ self.vision_tower.vision_model.head = nn.Identity()
688
+ self.vision_tower.requires_grad_(False)
689
+
690
+ self.is_loaded = True
691
+
692
+ def forward(self, images):
693
+ if type(images) is list:
694
+ image_features = []
695
+ for image in images:
696
+ image_forward_out = self.vision_tower(
697
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
698
+ output_hidden_states=True,
699
+ )
700
+ image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
701
+ assert image_feature.shape[-2] == 729
702
+ image_features.append(image_feature)
703
+ else:
704
+ image_forward_outs = self.vision_tower(
705
+ images.to(device=self.device, dtype=self.dtype),
706
+ output_hidden_states=True,
707
+ )
708
+ image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
709
+ assert image_features.shape[-2] == 729
710
+
711
+ return image_features
712
+
713
+ @property
714
+ def dummy_feature(self):
715
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
716
+
717
+ @property
718
+ def dtype(self):
719
+ for p in self.vision_tower.parameters():
720
+ return p.dtype
721
+
722
+ @property
723
+ def device(self):
724
+ for p in self.vision_tower.parameters():
725
+ return p.device
726
+
727
+ @property
728
+ def hidden_size(self):
729
+ return self.config.hidden_size
730
+
731
+ @property
732
+ def num_patches(self):
733
+ return (self.config.image_size // self.config.patch_size) ** 2
734
+
735
+ @property
736
+ def num_patches_per_side(self):
737
+ return self.config.image_size // self.config.patch_size
738
+ # return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]
739
+
740
+ @property
741
+ def image_size(self):
742
+ return self.config.image_size
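# --- Illustrative sketch (not part of this commit): expected shapes from SigLipVisionTower.
# The checkpoint name below is an assumed example; the real path comes from the training config,
# and instantiating the tower downloads weights, so the calls are left commented.
#
#   tower = SigLipVisionTower("google/siglip-so400m-patch14-384", vision_tower_cfg=None)
#   frames = torch.randn(2, 3, 384, 384)   # two preprocessed frames
#   feats = tower(frames)                  # (2, 729, 1152): 27 x 27 patches, hidden_size 1152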
egogpt/model/multimodal_projector/__pycache__/builder.cpython-310.pyc ADDED
Binary file (2.38 kB). View file
 
egogpt/model/multimodal_projector/__pycache__/pooler_projector.cpython-310.pyc ADDED
Binary file (1.44 kB). View file
 
egogpt/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,68 @@
1
+ import re
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .pooler_projector import PoolerProjector
7
+
8
+
9
+ class IdentityMap(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def forward(self, x, *args, **kwargs):
14
+ return x
15
+
16
+ @property
17
+ def config(self):
18
+ return {"mm_projector_type": "identity"}
19
+
20
+
21
+ class SimpleResBlock(nn.Module):
22
+ def __init__(self, channels):
23
+ super().__init__()
24
+ self.pre_norm = nn.LayerNorm(channels)
25
+
26
+ self.proj = nn.Sequential(
27
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
28
+ )
29
+
30
+ def forward(self, x):
31
+ x = self.pre_norm(x)
32
+ return x + self.proj(x)
33
+
34
+
35
+ def build_vision_projector(config, delay_load=False, **kwargs):
36
+ projector_type = getattr(config, "mm_projector_type", "linear")
37
+
38
+ if projector_type == "linear":
39
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
40
+
41
+ if projector_type == "pooler":
42
+ return PoolerProjector(config, kwargs["vision_cfg"])
43
+
44
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
45
+ if mlp_gelu_match:
46
+ mlp_depth = int(mlp_gelu_match.group(1))
47
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
48
+ for _ in range(1, mlp_depth):
49
+ modules.append(nn.GELU())
50
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
51
+ return nn.Sequential(*modules)
52
+
53
+ mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
54
+ if mlp_gelu_resnet_match:
55
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
56
+ res_depth = int(mlp_gelu_resnet_match.group(2))
57
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
58
+ for _ in range(1, mlp_depth):
59
+ modules.append(nn.GELU())
60
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
61
+ for _ in range(res_depth):
62
+ modules.append(SimpleResBlock(config.hidden_size))
63
+ return nn.Sequential(*modules)
64
+
65
+ if projector_type == "identity":
66
+ return IdentityMap()
67
+
68
+ raise ValueError(f"Unknown projector type: {projector_type}")
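# --- Illustrative sketch (not part of this commit): what the factory above builds for the
# common "mlp2x_gelu" setting. The config sizes are assumptions for demonstration.
from types import SimpleNamespace

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=4096)
projector = build_vision_projector(cfg)
# -> nn.Sequential(Linear(1152, 4096), GELU(), Linear(4096, 4096))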
egogpt/model/multimodal_projector/pooler_projector.py ADDED
@@ -0,0 +1,34 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.models.clip.modeling_clip import CLIPVisionModel
6
+
7
+
8
+ class PoolerProjector(nn.Module):
9
+ def __init__(self, config, vision_cfg):
10
+ super().__init__()
11
+ self._config = config
12
+ self.hw = vision_cfg.image_size // vision_cfg.patch_size
13
+
14
+ self.conv_pool = nn.Conv2d(
15
+ config.mm_hidden_size, config.hidden_size, kernel_size=2, stride=2
16
+ )
17
+
18
+ self.proj = nn.Sequential(
19
+ nn.GELU(),
20
+ nn.Linear(config.hidden_size, config.hidden_size),
21
+ )
22
+
23
+ def forward(self, x, *args, **kwargs):
24
+ height = width = self.hw
25
+ assert height * width == x.shape[1]
26
+ x = x.view(x.shape[0], height, width, -1).permute(0, 3, 1, 2)
27
+ x = self.conv_pool(x)
28
+ x = x.flatten(2).transpose(1, 2)
29
+ x = self.proj(x)
30
+ return x
31
+
32
+ @property
33
+ def config(self):
34
+ return {"mm_projector_type": "pooler"}
egogpt/model/multimodal_resampler/__pycache__/builder.cpython-310.pyc ADDED
Binary file (1.42 kB). View file
 
egogpt/model/multimodal_resampler/__pycache__/masked_drop.cpython-310.pyc ADDED
Binary file (2.46 kB). View file
 
egogpt/model/multimodal_resampler/__pycache__/perceiver.cpython-310.pyc ADDED
Binary file (4.86 kB). View file
 
egogpt/model/multimodal_resampler/__pycache__/qformer.cpython-310.pyc ADDED
Binary file (32.9 kB). View file
 
egogpt/model/multimodal_resampler/__pycache__/spatial_pool.cpython-310.pyc ADDED
Binary file (1.9 kB). View file
 
egogpt/model/multimodal_resampler/builder.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+ from .masked_drop import MaskedDrop
4
+ from .perceiver import PerceiverResampler
5
+ from .qformer import Qformer
6
+ from .spatial_pool import SpatialPool
7
+
8
+
9
+ class IdentityMap(torch.nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+
13
+ def forward(self, x, *args, **kwargs):
14
+ return x
15
+
16
+ @property
17
+ def config(self):
18
+ return {"mm_resampler_type": None}
19
+
20
+
21
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
22
+ resampler_type = getattr(model_args, "mm_resampler_type", None)
23
+ if resampler_type == "masked_drop":
24
+ return MaskedDrop(model_args)
25
+ elif resampler_type == "spatial_pool":
26
+ return SpatialPool(model_args, **kwargs)
27
+ elif resampler_type == "perceiver":
28
+ return PerceiverResampler(model_args, **kwargs)
29
+ elif resampler_type == "qformer":
30
+ return Qformer(model_args, **kwargs)
31
+ elif resampler_type is None:
32
+ return IdentityMap()
33
+
34
+ raise ValueError(f"Unknown resampler type: {resampler_type}")
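# --- Illustrative sketch (not part of this commit): with no resampler configured the factory
# above returns an identity mapping. The argument object is a stand-in for model_args.
from types import SimpleNamespace

resampler = build_vision_resampler(SimpleNamespace(mm_resampler_type=None))
feats = torch.randn(2, 729, 1152)
assert resampler(feats) is feats    # IdentityMap passes features through untouched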
egogpt/model/multimodal_resampler/masked_drop.py ADDED
@@ -0,0 +1,89 @@
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class MaskedDrop(nn.Module):
8
+ def __init__(self, model_args):
9
+ super().__init__()
10
+
11
+ self.mode = model_args.mm_mask_drop_mode
12
+ self.skip_percentage = model_args.mm_mask_drop_skip_percentage
13
+ self.ratio = model_args.mm_mask_drop_ratio
14
+ self.ratio_upper = model_args.mm_mask_drop_ratio_upper
15
+ self.ratio_lower = model_args.mm_mask_drop_ratio_lower
16
+
17
+ def forward(self, image_features, *args, **kwargs):
18
+ if not self.training:
19
+ return image_features
20
+
21
+ if self.skip_percentage > random.random():
22
+ return image_features
23
+
24
+ masked_features = []
25
+
26
+ for image_feature in image_features:
27
+ num_tokens = image_feature.shape[0]
28
+ if self.mode == "fixed":
29
+ num_keep = int(num_tokens * self.ratio)
30
+ masked_features.append(
31
+ self.random_masking(image_feature.unsqueeze(0), num_keep)[0][0]
32
+ )
33
+ elif self.mode == "range":
34
+ num_keep = int(
35
+ num_tokens * random.uniform(self.ratio_lower, self.ratio_upper)
36
+ )
37
+ masked_features.append(
38
+ self.random_masking(image_feature.unsqueeze(0), num_keep)[0]
39
+ )
40
+ elif self.mode == "cls_only":
41
+ masked_features.append(image_feature[0:1])
42
+ else:
43
+ raise ValueError(f"Unexpected masked drop mode: {self.mode}")
44
+
45
+ if self.mode not in ["range"] and (
46
+ type(image_features) is not list or self.mode in ["cls_only"]
47
+ ):
48
+ masked_features = torch.stack(masked_features, dim=0)
49
+
50
+ return masked_features
51
+
52
+ @property
53
+ def config(self):
54
+ return {
55
+ "mm_resampler_type": "masked_drop",
56
+ "mm_mask_drop_mode": self.mode,
57
+ "mm_mask_drop_skip_percentage": self.skip_percentage,
58
+ "mm_mask_drop_ratio": self.ratio,
59
+ "mm_mask_drop_ratio_upper": self.ratio_upper,
60
+ "mm_mask_drop_ratio_lower": self.ratio_lower,
61
+ }
62
+
63
+ def random_masking(self, x, len_keep):
64
+ """
65
+ Perform per-sample random masking by per-sample shuffling.
66
+ Per-sample shuffling is done by argsort random noise.
67
+ x: [N, L, D], sequence
68
+ """
69
+ N, L, D = x.shape # batch, length, dim
70
+
71
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
72
+
73
+ # sort noise for each sample
74
+ ids_shuffle = torch.argsort(
75
+ noise, dim=1
76
+ ) # ascend: small is keep, large is remove
77
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
78
+
79
+ # keep the first subset
80
+ ids_keep = ids_shuffle[:, :len_keep]
81
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
82
+
83
+ # generate the binary mask: 0 is keep, 1 is remove
84
+ mask = torch.ones([N, L], device=x.device)
85
+ mask[:, :len_keep] = 0
86
+ # unshuffle to get the binary mask
87
+ mask = torch.gather(mask, dim=1, index=ids_restore)
88
+
89
+ return x_masked, mask, ids_restore
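# --- Illustrative sketch (not part of this commit): random_masking keeps `len_keep` tokens per
# sample and reports which positions were dropped. All shapes and ratios are assumptions.
from types import SimpleNamespace

args = SimpleNamespace(mm_mask_drop_mode="fixed", mm_mask_drop_skip_percentage=0.0,
                       mm_mask_drop_ratio=0.5, mm_mask_drop_ratio_upper=None,
                       mm_mask_drop_ratio_lower=None)
dropper = MaskedDrop(args)
x = torch.randn(2, 8, 16)                                  # (N samples, L tokens, D dims)
x_masked, mask, ids_restore = dropper.random_masking(x, len_keep=3)
assert x_masked.shape == (2, 3, 16)                        # the 3 surviving tokens per sample
assert mask.shape == (2, 8) and mask.sum().item() == 10    # 0 = kept, 1 = dropped (5 per sample)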
egogpt/model/multimodal_resampler/perceiver.py ADDED
@@ -0,0 +1,172 @@
1
+ """
2
+ Taken from https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+
8
+ try:
9
+ from einops_exts import rearrange_many
10
+ except ImportError:
11
+ pass
12
+
13
+ from torch import einsum, nn
14
+
15
+
16
+ def exists(val):
17
+ return val is not None
18
+
19
+
20
+ def FeedForward(dim, mult=4):
21
+ inner_dim = int(dim * mult)
22
+ return nn.Sequential(
23
+ nn.LayerNorm(dim),
24
+ nn.Linear(dim, inner_dim, bias=False),
25
+ nn.GELU(),
26
+ nn.Linear(inner_dim, dim, bias=False),
27
+ )
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.heads = heads
35
+ inner_dim = dim_head * heads
36
+
37
+ self.norm_media = nn.LayerNorm(dim)
38
+ self.norm_latents = nn.LayerNorm(dim)
39
+
40
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
41
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
42
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ Args:
47
+ x (torch.Tensor): image features
48
+ shape (b, T, n1, D)
49
+ latent (torch.Tensor): latent features
50
+ shape (b, T, n2, D)
51
+ """
52
+ x = self.norm_media(x)
53
+ latents = self.norm_latents(latents)
54
+
55
+ h = self.heads
56
+
57
+ q = self.to_q(latents)
58
+ kv_input = torch.cat((x, latents), dim=-2)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
61
+ q = q * self.scale
62
+
63
+ # attention
64
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
65
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
66
+ attn = sim.softmax(dim=-1)
67
+
68
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
69
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
70
+ return self.to_out(out)
71
+
72
+
73
+ class PerceiverResamplerModule(nn.Module):
74
+ def __init__(
75
+ self,
76
+ *,
77
+ dim,
78
+ depth=6,
79
+ dim_head=64,
80
+ heads=8,
81
+ num_latents=64,
82
+ max_num_media=None,
83
+ max_num_frames=None,
84
+ ff_mult=4,
85
+ ):
86
+ super().__init__()
87
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
88
+ self.frame_embs = (
89
+ nn.Parameter(torch.randn(max_num_frames, dim))
90
+ if exists(max_num_frames)
91
+ else None
92
+ )
93
+ self.media_time_embs = (
94
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
95
+ if exists(max_num_media)
96
+ else None
97
+ )
98
+
99
+ self.layers = nn.ModuleList([])
100
+ for _ in range(depth):
101
+ self.layers.append(
102
+ nn.ModuleList(
103
+ [
104
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
105
+ FeedForward(dim=dim, mult=ff_mult)
106
+ if ff_mult > 0
107
+ else nn.Identity(),
108
+ ]
109
+ )
110
+ )
111
+
112
+ self.norm = nn.LayerNorm(dim)
113
+
114
+ def forward(self, x):
115
+ """
116
+ Args:
117
+ x (torch.Tensor): image features
118
+ shape (b, T, F, v, D)
119
+ Returns:
120
+ shape (b, T, n, D) where n is self.num_latents
121
+ """
122
+ b, T, F, v = x.shape[:4]
123
+
124
+ # frame and media time embeddings
125
+ if exists(self.frame_embs):
126
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
127
+ x = x + frame_embs
128
+ x = rearrange(
129
+ x, "b T F v d -> b T (F v) d"
130
+ ) # flatten the frame and spatial dimensions
131
+ if exists(self.media_time_embs):
132
+ x = x + self.media_time_embs[:T]
133
+
134
+ # blocks
135
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
136
+ for attn, ff in self.layers:
137
+ latents = attn(x, latents) + latents
138
+ latents = ff(latents) + latents
139
+ return self.norm(latents)
140
+
141
+
142
+ class PerceiverResampler(nn.Module):
143
+ def __init__(self, model_args, vision_tower):
144
+ super().__init__()
145
+
146
+ self.depth = model_args.mm_perceiver_depth
147
+ self.num_latents = model_args.mm_perceiver_latents
148
+ self.ff_mult = model_args.mm_perceiver_ff_mult
149
+ self.pretrained = model_args.mm_perceiver_pretrained
150
+
151
+ self.perceiver = PerceiverResamplerModule(
152
+ dim=vision_tower.hidden_size,
153
+ depth=self.depth,
154
+ num_latents=self.num_latents,
155
+ ff_mult=self.ff_mult,
156
+ )
157
+
158
+ if self.pretrained is not None:
159
+ self.load_state_dict(torch.load(self.pretrained))
160
+
161
+ def forward(self, image_features, *args, **kwargs):
162
+ return self.perceiver(image_features[:, None, None]).squeeze(1)
163
+
164
+ @property
165
+ def config(self):
166
+ return {
167
+ "mm_resampler_type": "perceiver",
168
+ "mm_perceiver_depth": self.depth,
169
+ "mm_perceiver_latents": self.num_latents,
170
+ "mm_perceiver_ff_mult": self.ff_mult,
171
+ "mm_perceiver_pretrained": self.pretrained,
172
+ }
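# --- Illustrative sketch (not part of this commit): compressing a 729-token frame to 64 latents
# with PerceiverResampler. Hyper-parameters are assumptions; einops / einops-exts must be installed.
from types import SimpleNamespace

args = SimpleNamespace(mm_perceiver_depth=2, mm_perceiver_latents=64,
                       mm_perceiver_ff_mult=4, mm_perceiver_pretrained=None)
tower = SimpleNamespace(hidden_size=1152)                  # stand-in for the vision tower
resampler = PerceiverResampler(args, tower)
frame_tokens = torch.randn(2, 729, 1152)                   # (B, v, D) per-frame patch features
latents = resampler(frame_tokens)                          # (B, 64, 1152) after cross-attention pooling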
egogpt/model/multimodal_resampler/qformer.py ADDED
@@ -0,0 +1,1281 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional, Tuple
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ import torch.utils.checkpoint
20
+ from torch import Tensor, device, dtype, nn
21
+ from torch.nn import CrossEntropyLoss
22
+ from transformers.activations import ACT2FN
23
+ from transformers.file_utils import ModelOutput
24
+ from transformers.modeling_outputs import (
25
+ BaseModelOutputWithPastAndCrossAttentions,
26
+ BaseModelOutputWithPoolingAndCrossAttentions,
27
+ CausalLMOutputWithCrossAttentions,
28
+ MaskedLMOutput,
29
+ MultipleChoiceModelOutput,
30
+ NextSentencePredictorOutput,
31
+ QuestionAnsweringModelOutput,
32
+ SequenceClassifierOutput,
33
+ TokenClassifierOutput,
34
+ )
35
+ from transformers.modeling_utils import (
36
+ PreTrainedModel,
37
+ apply_chunking_to_forward,
38
+ find_pruneable_heads_and_indices,
39
+ prune_linear_layer,
40
+ )
41
+ from transformers.models.bert.configuration_bert import BertConfig
42
+ from transformers.utils import logging
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ def disabled_train(self, mode=True):
48
+ """Overwrite model.train with this function to make sure train/eval mode
49
+ does not change anymore."""
50
+ return self
51
+
52
+
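# --- Illustrative sketch (not part of this commit): pinning a frozen module in eval mode with
# disabled_train so that later .train() calls become no-ops. The Linear module is a stand-in.
import types

frozen = nn.Linear(4, 4).eval()
frozen.train = types.MethodType(disabled_train, frozen)
frozen.train()                      # no-op: the training flag stays False
assert frozen.training is False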
53
+ class BertEmbeddings(nn.Module):
54
+ """Construct the embeddings from word and position embeddings."""
55
+
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.word_embeddings = nn.Embedding(
59
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
60
+ )
61
+ self.position_embeddings = nn.Embedding(
62
+ config.max_position_embeddings, config.hidden_size
63
+ )
64
+
65
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
66
+ # any TensorFlow checkpoint file
67
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
68
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
69
+
70
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
71
+ self.register_buffer(
72
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
73
+ )
74
+ self.position_embedding_type = getattr(
75
+ config, "position_embedding_type", "absolute"
76
+ )
77
+
78
+ self.config = config
79
+
80
+ def forward(
81
+ self,
82
+ input_ids=None,
83
+ position_ids=None,
84
+ query_embeds=None,
85
+ past_key_values_length=0,
86
+ ):
87
+ if input_ids is not None:
88
+ seq_length = input_ids.size()[1]
89
+ else:
90
+ seq_length = 0
91
+
92
+ if position_ids is None:
93
+ position_ids = self.position_ids[
94
+ :, past_key_values_length : seq_length + past_key_values_length
95
+ ].clone()
96
+
97
+ if input_ids is not None:
98
+ embeddings = self.word_embeddings(input_ids)
99
+ if self.position_embedding_type == "absolute":
100
+ position_embeddings = self.position_embeddings(position_ids)
101
+ embeddings = embeddings + position_embeddings
102
+
103
+ if query_embeds is not None:
104
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
105
+ else:
106
+ embeddings = query_embeds
107
+
108
+ embeddings = self.LayerNorm(embeddings)
109
+ embeddings = self.dropout(embeddings)
110
+ return embeddings
111
+
112
+
113
+ class BertSelfAttention(nn.Module):
114
+ def __init__(self, config, is_cross_attention):
115
+ super().__init__()
116
+ self.config = config
117
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
118
+ config, "embedding_size"
119
+ ):
120
+ raise ValueError(
121
+ "The hidden size (%d) is not a multiple of the number of attention "
122
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
123
+ )
124
+
125
+ self.num_attention_heads = config.num_attention_heads
126
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
127
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
128
+
129
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
130
+ if is_cross_attention:
131
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
132
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
133
+ else:
134
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
135
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
136
+
137
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
138
+ self.position_embedding_type = getattr(
139
+ config, "position_embedding_type", "absolute"
140
+ )
141
+ if (
142
+ self.position_embedding_type == "relative_key"
143
+ or self.position_embedding_type == "relative_key_query"
144
+ ):
145
+ self.max_position_embeddings = config.max_position_embeddings
146
+ self.distance_embedding = nn.Embedding(
147
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
148
+ )
149
+ self.save_attention = False
150
+
151
+ def save_attn_gradients(self, attn_gradients):
152
+ self.attn_gradients = attn_gradients
153
+
154
+ def get_attn_gradients(self):
155
+ return self.attn_gradients
156
+
157
+ def save_attention_map(self, attention_map):
158
+ self.attention_map = attention_map
159
+
160
+ def get_attention_map(self):
161
+ return self.attention_map
162
+
163
+ def transpose_for_scores(self, x):
164
+ new_x_shape = x.size()[:-1] + (
165
+ self.num_attention_heads,
166
+ self.attention_head_size,
167
+ )
168
+ x = x.view(*new_x_shape)
169
+ return x.permute(0, 2, 1, 3)
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states,
174
+ attention_mask=None,
175
+ head_mask=None,
176
+ encoder_hidden_states=None,
177
+ encoder_attention_mask=None,
178
+ past_key_value=None,
179
+ output_attentions=False,
180
+ ):
181
+ # If this is instantiated as a cross-attention module, the keys
182
+ # and values come from an encoder; the attention mask needs to be
183
+ # such that the encoder's padding tokens are not attended to.
184
+ is_cross_attention = encoder_hidden_states is not None
185
+
186
+ if is_cross_attention:
187
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
188
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
189
+ attention_mask = encoder_attention_mask
190
+ elif past_key_value is not None:
191
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
192
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
193
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
194
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
195
+ else:
196
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
198
+
199
+ mixed_query_layer = self.query(hidden_states)
200
+
201
+ query_layer = self.transpose_for_scores(mixed_query_layer)
202
+
203
+ past_key_value = (key_layer, value_layer)
204
+
205
+ # Take the dot product between "query" and "key" to get the raw attention scores.
206
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207
+
208
+ if (
209
+ self.position_embedding_type == "relative_key"
210
+ or self.position_embedding_type == "relative_key_query"
211
+ ):
212
+ seq_length = hidden_states.size()[1]
213
+ position_ids_l = torch.arange(
214
+ seq_length, dtype=torch.long, device=hidden_states.device
215
+ ).view(-1, 1)
216
+ position_ids_r = torch.arange(
217
+ seq_length, dtype=torch.long, device=hidden_states.device
218
+ ).view(1, -1)
219
+ distance = position_ids_l - position_ids_r
220
+ positional_embedding = self.distance_embedding(
221
+ distance + self.max_position_embeddings - 1
222
+ )
223
+ positional_embedding = positional_embedding.to(
224
+ dtype=query_layer.dtype
225
+ ) # fp16 compatibility
226
+
227
+ if self.position_embedding_type == "relative_key":
228
+ relative_position_scores = torch.einsum(
229
+ "bhld,lrd->bhlr", query_layer, positional_embedding
230
+ )
231
+ attention_scores = attention_scores + relative_position_scores
232
+ elif self.position_embedding_type == "relative_key_query":
233
+ relative_position_scores_query = torch.einsum(
234
+ "bhld,lrd->bhlr", query_layer, positional_embedding
235
+ )
236
+ relative_position_scores_key = torch.einsum(
237
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
238
+ )
239
+ attention_scores = (
240
+ attention_scores
241
+ + relative_position_scores_query
242
+ + relative_position_scores_key
243
+ )
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
252
+
253
+ if is_cross_attention and self.save_attention:
254
+ self.save_attention_map(attention_probs)
255
+ attention_probs.register_hook(self.save_attn_gradients)
256
+
257
+ # This is actually dropping out entire tokens to attend to, which might
258
+ # seem a bit unusual, but is taken from the original Transformer paper.
259
+ attention_probs_dropped = self.dropout(attention_probs)
260
+
261
+ # Mask heads if we want to
262
+ if head_mask is not None:
263
+ attention_probs_dropped = attention_probs_dropped * head_mask
264
+
265
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
266
+
267
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
268
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
269
+ context_layer = context_layer.view(*new_context_layer_shape)
270
+
271
+ outputs = (
272
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
273
+ )
274
+
275
+ outputs = outputs + (past_key_value,)
276
+ return outputs
277
+
278
+
279
+ class BertSelfOutput(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
283
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
284
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
285
+
286
+ def forward(self, hidden_states, input_tensor):
287
+ hidden_states = self.dense(hidden_states)
288
+ hidden_states = self.dropout(hidden_states)
289
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
290
+ return hidden_states
291
+
292
+
293
+ class BertAttention(nn.Module):
294
+ def __init__(self, config, is_cross_attention=False):
295
+ super().__init__()
296
+ self.self = BertSelfAttention(config, is_cross_attention)
297
+ self.output = BertSelfOutput(config)
298
+ self.pruned_heads = set()
299
+
300
+ def prune_heads(self, heads):
301
+ if len(heads) == 0:
302
+ return
303
+ heads, index = find_pruneable_heads_and_indices(
304
+ heads,
305
+ self.self.num_attention_heads,
306
+ self.self.attention_head_size,
307
+ self.pruned_heads,
308
+ )
309
+
310
+ # Prune linear layers
311
+ self.self.query = prune_linear_layer(self.self.query, index)
312
+ self.self.key = prune_linear_layer(self.self.key, index)
313
+ self.self.value = prune_linear_layer(self.self.value, index)
314
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
315
+
316
+ # Update hyper params and store pruned heads
317
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
318
+ self.self.all_head_size = (
319
+ self.self.attention_head_size * self.self.num_attention_heads
320
+ )
321
+ self.pruned_heads = self.pruned_heads.union(heads)
322
+
323
+ def forward(
324
+ self,
325
+ hidden_states,
326
+ attention_mask=None,
327
+ head_mask=None,
328
+ encoder_hidden_states=None,
329
+ encoder_attention_mask=None,
330
+ past_key_value=None,
331
+ output_attentions=False,
332
+ ):
333
+ self_outputs = self.self(
334
+ hidden_states,
335
+ attention_mask,
336
+ head_mask,
337
+ encoder_hidden_states,
338
+ encoder_attention_mask,
339
+ past_key_value,
340
+ output_attentions,
341
+ )
342
+ attention_output = self.output(self_outputs[0], hidden_states)
343
+
344
+ outputs = (attention_output,) + self_outputs[
345
+ 1:
346
+ ] # add attentions if we output them
347
+ return outputs
348
+
349
+
350
+ class BertIntermediate(nn.Module):
351
+ def __init__(self, config):
352
+ super().__init__()
353
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
354
+ if isinstance(config.hidden_act, str):
355
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
356
+ else:
357
+ self.intermediate_act_fn = config.hidden_act
358
+
359
+ def forward(self, hidden_states):
360
+ hidden_states = self.dense(hidden_states)
361
+ hidden_states = self.intermediate_act_fn(hidden_states)
362
+ return hidden_states
363
+
364
+
365
+ class BertOutput(nn.Module):
366
+ def __init__(self, config):
367
+ super().__init__()
368
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
369
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
370
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
371
+
372
+ def forward(self, hidden_states, input_tensor):
373
+ hidden_states = self.dense(hidden_states)
374
+ hidden_states = self.dropout(hidden_states)
375
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
376
+ return hidden_states
377
+
378
+
379
+ class BertLayer(nn.Module):
380
+ def __init__(self, config, layer_num):
381
+ super().__init__()
382
+ self.config = config
383
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
384
+ self.seq_len_dim = 1
385
+ self.attention = BertAttention(config)
386
+ self.layer_num = layer_num
387
+ if (
388
+ self.config.add_cross_attention
389
+ and layer_num % self.config.cross_attention_freq == 0
390
+ ):
391
+ self.crossattention = BertAttention(
392
+ config, is_cross_attention=self.config.add_cross_attention
393
+ )
394
+ self.has_cross_attention = True
395
+ else:
396
+ self.has_cross_attention = False
397
+ self.intermediate = BertIntermediate(config)
398
+ self.output = BertOutput(config)
399
+
400
+ self.intermediate_query = BertIntermediate(config)
401
+ self.output_query = BertOutput(config)
402
+
403
+ def forward(
404
+ self,
405
+ hidden_states,
406
+ attention_mask=None,
407
+ head_mask=None,
408
+ encoder_hidden_states=None,
409
+ encoder_attention_mask=None,
410
+ past_key_value=None,
411
+ output_attentions=False,
412
+ query_length=0,
413
+ ):
414
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
415
+ self_attn_past_key_value = (
416
+ past_key_value[:2] if past_key_value is not None else None
417
+ )
418
+ self_attention_outputs = self.attention(
419
+ hidden_states,
420
+ attention_mask,
421
+ head_mask,
422
+ output_attentions=output_attentions,
423
+ past_key_value=self_attn_past_key_value,
424
+ )
425
+ attention_output = self_attention_outputs[0]
426
+ outputs = self_attention_outputs[1:-1]
427
+
428
+ present_key_value = self_attention_outputs[-1]
429
+
430
+ if query_length > 0:
431
+ query_attention_output = attention_output[:, :query_length, :]
432
+
433
+ if self.has_cross_attention:
434
+ assert (
435
+ encoder_hidden_states is not None
436
+ ), "encoder_hidden_states must be given for cross-attention layers"
437
+ cross_attention_outputs = self.crossattention(
438
+ query_attention_output,
439
+ attention_mask,
440
+ head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ output_attentions=output_attentions,
444
+ )
445
+ query_attention_output = cross_attention_outputs[0]
446
+ outputs = (
447
+ outputs + cross_attention_outputs[1:-1]
448
+ ) # add cross attentions if we output attention weights
449
+
450
+ layer_output = apply_chunking_to_forward(
451
+ self.feed_forward_chunk_query,
452
+ self.chunk_size_feed_forward,
453
+ self.seq_len_dim,
454
+ query_attention_output,
455
+ )
456
+ if attention_output.shape[1] > query_length:
457
+ layer_output_text = apply_chunking_to_forward(
458
+ self.feed_forward_chunk,
459
+ self.chunk_size_feed_forward,
460
+ self.seq_len_dim,
461
+ attention_output[:, query_length:, :],
462
+ )
463
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
464
+ else:
465
+ layer_output = apply_chunking_to_forward(
466
+ self.feed_forward_chunk,
467
+ self.chunk_size_feed_forward,
468
+ self.seq_len_dim,
469
+ attention_output,
470
+ )
471
+ outputs = (layer_output,) + outputs
472
+
473
+ outputs = outputs + (present_key_value,)
474
+
475
+ return outputs
476
+
477
+ def feed_forward_chunk(self, attention_output):
478
+ intermediate_output = self.intermediate(attention_output)
479
+ layer_output = self.output(intermediate_output, attention_output)
480
+ return layer_output
481
+
482
+ def feed_forward_chunk_query(self, attention_output):
483
+ intermediate_output = self.intermediate_query(attention_output)
484
+ layer_output = self.output_query(intermediate_output, attention_output)
485
+ return layer_output
486
+
487
+
488
+ class BertEncoder(nn.Module):
489
+ def __init__(self, config):
490
+ super().__init__()
491
+ self.config = config
492
+ self.layer = nn.ModuleList(
493
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
494
+ )
495
+
496
+ def forward(
497
+ self,
498
+ hidden_states,
499
+ attention_mask=None,
500
+ head_mask=None,
501
+ encoder_hidden_states=None,
502
+ encoder_attention_mask=None,
503
+ past_key_values=None,
504
+ use_cache=None,
505
+ output_attentions=False,
506
+ output_hidden_states=False,
507
+ return_dict=True,
508
+ query_length=0,
509
+ ):
510
+ all_hidden_states = () if output_hidden_states else None
511
+ all_self_attentions = () if output_attentions else None
512
+ all_cross_attentions = (
513
+ () if output_attentions and self.config.add_cross_attention else None
514
+ )
515
+
516
+ next_decoder_cache = () if use_cache else None
517
+
518
+ for i in range(self.config.num_hidden_layers):
519
+ layer_module = self.layer[i]
520
+ if output_hidden_states:
521
+ all_hidden_states = all_hidden_states + (hidden_states,)
522
+
523
+ layer_head_mask = head_mask[i] if head_mask is not None else None
524
+ past_key_value = past_key_values[i] if past_key_values is not None else None
525
+
526
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
527
+ if use_cache:
528
+ logger.warning(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
970
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
971
+
972
+ def __init__(self, config):
973
+ super().__init__(config)
974
+
975
+ self.bert = BertModel(config, add_pooling_layer=False)
976
+ self.cls = BertOnlyMLMHead(config)
977
+
978
+ self.init_weights()
979
+
980
+ def get_output_embeddings(self):
981
+ return self.cls.predictions.decoder
982
+
983
+ def set_output_embeddings(self, new_embeddings):
984
+ self.cls.predictions.decoder = new_embeddings
985
+
986
+ def forward(
987
+ self,
988
+ input_ids=None,
989
+ attention_mask=None,
990
+ position_ids=None,
991
+ head_mask=None,
992
+ query_embeds=None,
993
+ encoder_hidden_states=None,
994
+ encoder_attention_mask=None,
995
+ labels=None,
996
+ past_key_values=None,
997
+ use_cache=True,
998
+ output_attentions=None,
999
+ output_hidden_states=None,
1000
+ return_dict=None,
1001
+ return_logits=False,
1002
+ is_decoder=True,
1003
+ reduction="mean",
1004
+ ):
1005
+ r"""
1006
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1007
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1008
+ the model is configured as a decoder.
1009
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1010
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1011
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1012
+ - 1 for tokens that are **not masked**,
1013
+ - 0 for tokens that are **masked**.
1014
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1015
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1016
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1017
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1018
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1019
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1020
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1021
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1022
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1023
+ use_cache (:obj:`bool`, `optional`):
1024
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1025
+ decoding (see :obj:`past_key_values`).
1026
+ Returns:
1027
+ Example::
1028
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1029
+ >>> import torch
1030
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1031
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1032
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1033
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1034
+ >>> outputs = model(**inputs)
1035
+ >>> prediction_logits = outputs.logits
1036
+ """
1037
+ return_dict = (
1038
+ return_dict if return_dict is not None else self.config.use_return_dict
1039
+ )
1040
+ if labels is not None:
1041
+ use_cache = False
1042
+ if past_key_values is not None:
1043
+ query_embeds = None
1044
+
1045
+ outputs = self.bert(
1046
+ input_ids,
1047
+ attention_mask=attention_mask,
1048
+ position_ids=position_ids,
1049
+ head_mask=head_mask,
1050
+ query_embeds=query_embeds,
1051
+ encoder_hidden_states=encoder_hidden_states,
1052
+ encoder_attention_mask=encoder_attention_mask,
1053
+ past_key_values=past_key_values,
1054
+ use_cache=use_cache,
1055
+ output_attentions=output_attentions,
1056
+ output_hidden_states=output_hidden_states,
1057
+ return_dict=return_dict,
1058
+ is_decoder=is_decoder,
1059
+ )
1060
+
1061
+ sequence_output = outputs[0]
1062
+ if query_embeds is not None:
1063
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1064
+
1065
+ prediction_scores = self.cls(sequence_output)
1066
+
1067
+ if return_logits:
1068
+ return prediction_scores[:, :-1, :].contiguous()
1069
+
1070
+ lm_loss = None
1071
+ if labels is not None:
1072
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1073
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1074
+ labels = labels[:, 1:].contiguous()
1075
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1076
+ lm_loss = loss_fct(
1077
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1078
+ labels.view(-1),
1079
+ )
1080
+ if reduction == "none":
1081
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1082
+
1083
+ if not return_dict:
1084
+ output = (prediction_scores,) + outputs[2:]
1085
+ return ((lm_loss,) + output) if lm_loss is not None else output
1086
+
1087
+ return CausalLMOutputWithCrossAttentions(
1088
+ loss=lm_loss,
1089
+ logits=prediction_scores,
1090
+ past_key_values=outputs.past_key_values,
1091
+ hidden_states=outputs.hidden_states,
1092
+ attentions=outputs.attentions,
1093
+ cross_attentions=outputs.cross_attentions,
1094
+ )
1095
+
1096
+ def prepare_inputs_for_generation(
1097
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1098
+ ):
1099
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1100
+ if attention_mask is None:
1101
+ attention_mask = input_ids.new_ones(input_ids.shape)
1102
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1103
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1104
+
1105
+ # cut decoder_input_ids if past is used
1106
+ if past is not None:
1107
+ input_ids = input_ids[:, -1:]
1108
+
1109
+ return {
1110
+ "input_ids": input_ids,
1111
+ "query_embeds": query_embeds,
1112
+ "attention_mask": attention_mask,
1113
+ "past_key_values": past,
1114
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1115
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1116
+ "is_decoder": True,
1117
+ }
1118
+
1119
+ def _reorder_cache(self, past, beam_idx):
1120
+ reordered_past = ()
1121
+ for layer_past in past:
1122
+ reordered_past += (
1123
+ tuple(
1124
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1125
+ ),
1126
+ )
1127
+ return reordered_past
1128
+
1129
+
1130
+ class BertForMaskedLM(BertPreTrainedModel):
1131
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1132
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1133
+
1134
+ def __init__(self, config):
1135
+ super().__init__(config)
1136
+
1137
+ self.bert = BertModel(config, add_pooling_layer=False)
1138
+ self.cls = BertOnlyMLMHead(config)
1139
+
1140
+ self.init_weights()
1141
+
1142
+ def get_output_embeddings(self):
1143
+ return self.cls.predictions.decoder
1144
+
1145
+ def set_output_embeddings(self, new_embeddings):
1146
+ self.cls.predictions.decoder = new_embeddings
1147
+
1148
+ def forward(
1149
+ self,
1150
+ input_ids=None,
1151
+ attention_mask=None,
1152
+ position_ids=None,
1153
+ head_mask=None,
1154
+ query_embeds=None,
1155
+ encoder_hidden_states=None,
1156
+ encoder_attention_mask=None,
1157
+ labels=None,
1158
+ output_attentions=None,
1159
+ output_hidden_states=None,
1160
+ return_dict=None,
1161
+ return_logits=False,
1162
+ is_decoder=False,
1163
+ ):
1164
+ r"""
1165
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1166
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1167
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1168
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1169
+ """
1170
+
1171
+ return_dict = (
1172
+ return_dict if return_dict is not None else self.config.use_return_dict
1173
+ )
1174
+
1175
+ outputs = self.bert(
1176
+ input_ids,
1177
+ attention_mask=attention_mask,
1178
+ position_ids=position_ids,
1179
+ head_mask=head_mask,
1180
+ query_embeds=query_embeds,
1181
+ encoder_hidden_states=encoder_hidden_states,
1182
+ encoder_attention_mask=encoder_attention_mask,
1183
+ output_attentions=output_attentions,
1184
+ output_hidden_states=output_hidden_states,
1185
+ return_dict=return_dict,
1186
+ is_decoder=is_decoder,
1187
+ )
1188
+
1189
+ if query_embeds is not None:
1190
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1191
+ prediction_scores = self.cls(sequence_output)
1192
+
1193
+ if return_logits:
1194
+ return prediction_scores
1195
+
1196
+ masked_lm_loss = None
1197
+ if labels is not None:
1198
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1199
+ masked_lm_loss = loss_fct(
1200
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1201
+ )
1202
+
1203
+ if not return_dict:
1204
+ output = (prediction_scores,) + outputs[2:]
1205
+ return (
1206
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1207
+ )
1208
+
1209
+ return MaskedLMOutput(
1210
+ loss=masked_lm_loss,
1211
+ logits=prediction_scores,
1212
+ hidden_states=outputs.hidden_states,
1213
+ attentions=outputs.attentions,
1214
+ )
1215
+
1216
+
1217
+ class Qformer(nn.Module):
1218
+ def __init__(self, model_args, vision_tower):
1219
+ super().__init__()
1220
+
1221
+ self.depth = model_args.mm_qformer_depth
1222
+ self.num_latents = model_args.mm_qformer_latents
1223
+ self.pretrained = model_args.mm_qformer_pretrained
1224
+
1225
+ self.Qformer, self.query_tokens, self.ln_vision = self.build_Qformer(
1226
+ vision_tower.hidden_size, self.depth, self.num_latents
1227
+ )
1228
+
1229
+ if self.pretrained is not None:
1230
+ pretrained_dict = torch.load(self.pretrained, map_location="cpu")["model"]
1231
+ pretrained_dict = {
1232
+ k: v for k, v in pretrained_dict.items() if not k.startswith("t5_proj")
1233
+ }
1234
+ self.load_state_dict(pretrained_dict)
1235
+
1236
+ def build_Qformer(self, vision_width, cross_attention_freq, num_query_token):
1237
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
1238
+ encoder_config.encoder_width = vision_width
1239
+ # insert cross-attention layer every other block
1240
+ encoder_config.add_cross_attention = True
1241
+ encoder_config.cross_attention_freq = cross_attention_freq
1242
+ encoder_config.query_length = num_query_token
1243
+ Qformer = BertLMHeadModel(config=encoder_config)
1244
+ query_tokens = nn.Parameter(
1245
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
1246
+ )
1247
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
1248
+ Qformer.cls = None
1249
+ Qformer.bert.embeddings.word_embeddings = None
1250
+ Qformer.bert.embeddings.position_embeddings = None
1251
+ for layer in Qformer.bert.encoder.layer:
1252
+ layer.output = None
1253
+ layer.intermediate = None
1254
+ return Qformer, query_tokens, nn.LayerNorm(vision_width)
1255
+
1256
+ def forward(self, image_features, *args, **kwargs):
1257
+ x = self.ln_vision(image_features)
1258
+ image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
1259
+
1260
+ query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
1261
+ query_output = self.Qformer.bert(
1262
+ query_embeds=query_tokens,
1263
+ encoder_hidden_states=x,
1264
+ encoder_attention_mask=image_atts,
1265
+ return_dict=True,
1266
+ )
1267
+
1268
+ return query_output.last_hidden_state
1269
+
1270
+ @property
1271
+ def hidden_size(self):
1272
+ return 768
1273
+
1274
+ @property
1275
+ def config(self):
1276
+ return {
1277
+ "mm_resampler_type": "qformer",
1278
+ "mm_qformer_depth": self.depth,
1279
+ "mm_qformer_latents": self.num_latents,
1280
+ "mm_qformer_pretrained": self.pretrained,
1281
+ }
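A minimal usage sketch for the Q-Former resampler defined above, not part of the committed file: the SimpleNamespace objects stand in for the real model_args and vision tower (their field values here are assumptions), and building the module needs a locally cached bert-base-uncased config.

from types import SimpleNamespace

import torch

# Hypothetical stand-ins: mm_qformer_depth is forwarded as the cross-attention
# frequency and mm_qformer_latents as the number of learnable query tokens.
model_args = SimpleNamespace(
    mm_qformer_depth=2,
    mm_qformer_latents=32,
    mm_qformer_pretrained=None,
)
vision_tower = SimpleNamespace(hidden_size=1024)

resampler = Qformer(model_args, vision_tower)

# Dummy ViT patch features: (batch, num_patches, vision_hidden_size).
image_features = torch.randn(2, 576, vision_tower.hidden_size)
queries = resampler(image_features)
print(queries.shape)  # torch.Size([2, 32, 768]) -- one 768-d vector per query token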
egogpt/model/multimodal_resampler/spatial_pool.py ADDED
@@ -0,0 +1,57 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class SpatialPool(nn.Module):
8
+ def __init__(self, model_args, vision_tower):
9
+ super().__init__()
10
+
11
+ self.mode = model_args.mm_spatial_pool_mode
12
+ self.stride = model_args.mm_spatial_pool_stride
13
+ self.out_channels = getattr(
14
+ model_args, "mm_spatial_pool_out_channels", vision_tower.hidden_size
15
+ )
16
+
17
+ if self.mode == "average":
18
+ self.pool = nn.AvgPool2d(kernel_size=self.stride, stride=self.stride)
19
+ elif self.mode == "max":
20
+ self.pool = nn.MaxPool2d(kernel_size=self.stride, stride=self.stride)
21
+ elif self.mode == "conv":
22
+ self.pool = nn.Conv2d(
23
+ in_channels=vision_tower.hidden_size,
24
+ out_channels=self.out_channels,
25
+ kernel_size=self.stride,
26
+ stride=self.stride,
27
+ )
28
+ else:
29
+ raise ValueError(f"Unknown pooling mode: {self.pool}.")
30
+
31
+ def forward(self, image_features, images, *args, **kwargs):
32
+ ori_W = int(
33
+ math.sqrt(image_features.shape[1] * images.shape[3] // images.shape[2])
34
+ )
35
+ ori_H = int(ori_W * images.shape[2] // images.shape[3])
36
+
37
+ B, _, F = image_features.shape
38
+
39
+ image_features_spatial = image_features.view(B, ori_H, ori_W, F).permute(
40
+ 0, 3, 1, 2
41
+ )
42
+ image_features_spatial_pool = self.pool(image_features_spatial)
43
+
44
+ return image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous()
45
+
46
+ @property
47
+ def config(self):
48
+ return {
49
+ "mm_resampler_type": "spatial_pool",
50
+ "mm_spatial_pool_stride": self.stride,
51
+ "mm_spatial_pool_mode": self.mode,
52
+ "mm_spatial_pool_out_channels": self.out_channels,
53
+ }
54
+
55
+ @property
56
+ def hidden_size(self):
57
+ return self.out_channels
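A short usage sketch for SpatialPool, with illustrative argument values: average pooling with stride 2 over the 24x24 patch grid of a square input halves each spatial side, so 576 visual tokens become 144.

from types import SimpleNamespace

import torch

model_args = SimpleNamespace(mm_spatial_pool_mode="average", mm_spatial_pool_stride=2)
vision_tower = SimpleNamespace(hidden_size=1024)
pool = SpatialPool(model_args, vision_tower)

# 576 = 24 x 24 patch features from a hypothetical square 336x336 image.
image_features = torch.randn(2, 576, 1024)
images = torch.randn(2, 3, 336, 336)
pooled = pool(image_features, images)
print(pooled.shape)  # torch.Size([2, 144, 1024])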
egogpt/model/speech_encoder/__pycache__/audio.cpython-310.pyc ADDED
Binary file (4.62 kB). View file
 
egogpt/model/speech_encoder/__pycache__/builder.cpython-310.pyc ADDED
Binary file (480 Bytes). View file
 
egogpt/model/speech_encoder/__pycache__/decoding.cpython-310.pyc ADDED
Binary file (26.1 kB). View file
 
egogpt/model/speech_encoder/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.4 kB). View file
 
egogpt/model/speech_encoder/__pycache__/speech_encoder.cpython-310.pyc ADDED
Binary file (5.68 kB). View file
 
egogpt/model/speech_encoder/__pycache__/timing.cpython-310.pyc ADDED
Binary file (9.68 kB). View file
 
egogpt/model/speech_encoder/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
egogpt/model/speech_encoder/__pycache__/transcribe.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
egogpt/model/speech_encoder/__pycache__/utils.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
egogpt/model/speech_encoder/audio.py ADDED
@@ -0,0 +1,157 @@
1
+ import os
2
+ from functools import lru_cache
3
+ from subprocess import CalledProcessError, run
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from .utils import exact_div
11
+
12
+ # hard-coded audio hyperparameters
13
+ SAMPLE_RATE = 16000
14
+ N_FFT = 400
15
+ HOP_LENGTH = 160
16
+ CHUNK_LENGTH = 30
17
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
18
+ N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
19
+
20
+ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
21
+ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
22
+ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
23
+
24
+
25
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
26
+ """
27
+ Open an audio file and read as mono waveform, resampling as necessary
28
+
29
+ Parameters
30
+ ----------
31
+ file: str
32
+ The audio file to open
33
+
34
+ sr: int
35
+ The sample rate to resample the audio if necessary
36
+
37
+ Returns
38
+ -------
39
+ A NumPy array containing the audio waveform, in float32 dtype.
40
+ """
41
+
42
+ # This launches a subprocess to decode audio while down-mixing
43
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
44
+ # fmt: off
45
+ cmd = [
46
+ "ffmpeg",
47
+ "-nostdin",
48
+ "-threads", "0",
49
+ "-i", file,
50
+ "-f", "s16le",
51
+ "-ac", "1",
52
+ "-acodec", "pcm_s16le",
53
+ "-ar", str(sr),
54
+ "-"
55
+ ]
56
+ # fmt: on
57
+ try:
58
+ out = run(cmd, capture_output=True, check=True).stdout
59
+ except CalledProcessError as e:
60
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
61
+
62
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
63
+
64
+
65
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
66
+ """
67
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
68
+ """
69
+ if torch.is_tensor(array):
70
+ if array.shape[axis] > length:
71
+ array = array.index_select(
72
+ dim=axis, index=torch.arange(length, device=array.device)
73
+ )
74
+
75
+ if array.shape[axis] < length:
76
+ pad_widths = [(0, 0)] * array.ndim
77
+ pad_widths[axis] = (0, length - array.shape[axis])
78
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
79
+ else:
80
+ if array.shape[axis] > length:
81
+ array = array.take(indices=range(length), axis=axis)
82
+
83
+ if array.shape[axis] < length:
84
+ pad_widths = [(0, 0)] * array.ndim
85
+ pad_widths[axis] = (0, length - array.shape[axis])
86
+ array = np.pad(array, pad_widths)
87
+
88
+ return array
89
+
90
+
91
+ @lru_cache(maxsize=None)
92
+ def mel_filters(device, n_mels: int) -> torch.Tensor:
93
+ """
94
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
95
+ Allows decoupling librosa dependency; saved using:
96
+
97
+ np.savez_compressed(
98
+ "mel_filters.npz",
99
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
100
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
101
+ )
102
+ """
103
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
104
+
105
+ filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
106
+ with np.load(filters_path, allow_pickle=False) as f:
107
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
108
+
109
+
110
+ def log_mel_spectrogram(
111
+ audio: Union[str, np.ndarray, torch.Tensor],
112
+ n_mels: int = 80,
113
+ padding: int = 0,
114
+ device: Optional[Union[str, torch.device]] = None,
115
+ ):
116
+ """
117
+ Compute the log-Mel spectrogram of
118
+
119
+ Parameters
120
+ ----------
121
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
122
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
123
+
124
+ n_mels: int
125
+ The number of Mel-frequency filters, only 80 is supported
126
+
127
+ padding: int
128
+ Number of zero samples to pad to the right
129
+
130
+ device: Optional[Union[str, torch.device]]
131
+ If given, the audio tensor is moved to this device before STFT
132
+
133
+ Returns
134
+ -------
135
+ torch.Tensor, shape = (80, n_frames)
136
+ A Tensor that contains the Mel spectrogram
137
+ """
138
+ if not torch.is_tensor(audio):
139
+ if isinstance(audio, str):
140
+ audio = load_audio(audio)
141
+ audio = torch.from_numpy(audio)
142
+
143
+ if device is not None:
144
+ audio = audio.to(device)
145
+ if padding > 0:
146
+ audio = F.pad(audio, (0, padding))
147
+ window = torch.hann_window(N_FFT).to(audio.device)
148
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
149
+ magnitudes = stft[..., :-1].abs() ** 2
150
+
151
+ filters = mel_filters(audio.device, n_mels)
152
+ mel_spec = filters @ magnitudes
153
+
154
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
155
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
156
+ log_spec = (log_spec + 4.0) / 4.0
157
+ return log_spec
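A quick sanity-check sketch for this audio front end: it pads a synthetic one-second tone to the 30-second window and computes 80-bin log-Mel features, assuming the bundled assets/mel_filters.npz sits next to this module.

import numpy as np
import torch

# One second of a 440 Hz tone at the expected 16 kHz sample rate.
t = np.arange(SAMPLE_RATE, dtype=np.float32) / SAMPLE_RATE
waveform = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

# Pad to N_SAMPLES (30 s), then compute the log-Mel spectrogram.
waveform = pad_or_trim(torch.from_numpy(waveform))
mel = log_mel_spectrogram(waveform, n_mels=80)
print(mel.shape)  # torch.Size([80, 3000]) -- 80 mel bins x N_FRAMES frames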
egogpt/model/speech_encoder/builder.py ADDED
@@ -0,0 +1,9 @@
1
+ from .speech_encoder import WhisperWrappedEncoder
2
+
3
+
4
+ def build_speech_encoder(config):
5
+ speech_encoder_type = getattr(config, "speech_encoder_type", None)
6
+ if "whisper" in speech_encoder_type.lower():
7
+ return WhisperWrappedEncoder(config)
8
+
9
+ raise ValueError(f"Unknown speech encoder: {speech_encoder_type}")
egogpt/model/speech_encoder/decoding.py ADDED
@@ -0,0 +1,826 @@
1
+ from dataclasses import dataclass, field, replace
2
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .model import Whisper
16
+
17
+
18
+ @torch.no_grad()
19
+ def detect_language(
20
+ model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
21
+ ) -> Tuple[Tensor, List[dict]]:
22
+ """
23
+ Detect the spoken language in the audio, and return them as a list of strings, along with the ids
24
+ of the most probable language tokens and the probability distribution over all language tokens.
25
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
26
+
27
+ Returns
28
+ -------
29
+ language_tokens : Tensor, shape = (n_audio,)
30
+ ids of the most probable language tokens, which appear after the startoftranscript token.
31
+ language_probs : List[Dict[str, float]], length = n_audio
32
+ list of dictionaries containing the probability distribution over all languages.
33
+ """
34
+ if tokenizer is None:
35
+ tokenizer = get_tokenizer(
36
+ model.is_multilingual, num_languages=model.num_languages
37
+ )
38
+ if (
39
+ tokenizer.language is None
40
+ or tokenizer.language_token not in tokenizer.sot_sequence
41
+ ):
42
+ raise ValueError(
43
+ "This model doesn't have language tokens so it can't perform lang id"
44
+ )
45
+
46
+ single = mel.ndim == 2
47
+ if single:
48
+ mel = mel.unsqueeze(0)
49
+
50
+ # skip encoder forward pass if already-encoded audio features were given
51
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
52
+ mel = model.encoder(mel)
53
+
54
+ # forward pass using a single token, startoftranscript
55
+ n_audio = mel.shape[0]
56
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
57
+ logits = model.logits(x, mel)[:, 0]
58
+
59
+ # collect detected languages; suppress all non-language tokens
60
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
61
+ mask[list(tokenizer.all_language_tokens)] = False
62
+ logits[:, mask] = -np.inf
63
+ language_tokens = logits.argmax(dim=-1)
64
+ language_token_probs = logits.softmax(dim=-1).cpu()
65
+ language_probs = [
66
+ {
67
+ c: language_token_probs[i, j].item()
68
+ for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
69
+ }
70
+ for i in range(n_audio)
71
+ ]
72
+
73
+ if single:
74
+ language_tokens = language_tokens[0]
75
+ language_probs = language_probs[0]
76
+
77
+ return language_tokens, language_probs
78
+
79
+
80
+ @dataclass(frozen=True)
81
+ class DecodingOptions:
82
+ # whether to perform X->X "transcribe" or X->English "translate"
83
+ task: str = "transcribe"
84
+
85
+ # language that the audio is in; uses detected language if None
86
+ language: Optional[str] = None
87
+
88
+ # sampling-related options
89
+ temperature: float = 0.0
90
+ sample_len: Optional[int] = None # maximum number of tokens to sample
91
+ best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
92
+ beam_size: Optional[int] = None # number of beams in beam search, if t == 0
93
+ patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
94
+
95
+ # "alpha" in Google NMT, or None for length norm, when ranking generations
96
+ # to select which to return among the beams or best-of-N samples
97
+ length_penalty: Optional[float] = None
98
+
99
+ # text or tokens to feed as the prompt or the prefix; for more info:
100
+ # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
101
+ prompt: Optional[Union[str, List[int]]] = None # for the previous context
102
+ prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
103
+
104
+ # list of tokens ids (or comma-separated token ids) to suppress
105
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
106
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
107
+ suppress_blank: bool = True # this will suppress blank outputs
108
+
109
+ # timestamp sampling options
110
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
111
+ max_initial_timestamp: Optional[float] = 1.0
112
+
113
+ # implementation details
114
+ fp16: bool = True # use fp16 for most of the calculation
115
+
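A brief sketch of constructing these options with illustrative field values: DecodingOptions is a frozen dataclass, so variants are derived with dataclasses.replace, which is imported at the top of this file.

# Greedy decoding at temperature 0, and a beam-search variant of the same options.
greedy_options = DecodingOptions(task="transcribe", language="en", temperature=0.0)
beam_options = replace(greedy_options, beam_size=5, patience=1.0)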
116
+
117
+ @dataclass(frozen=True)
118
+ class DecodingResult:
119
+ audio_features: Tensor
120
+ language: str
121
+ language_probs: Optional[Dict[str, float]] = None
122
+ tokens: List[int] = field(default_factory=list)
123
+ text: str = ""
124
+ avg_logprob: float = np.nan
125
+ no_speech_prob: float = np.nan
126
+ temperature: float = np.nan
127
+ compression_ratio: float = np.nan
128
+
129
+
130
+ class Inference:
131
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
132
+ """Perform a forward pass on the decoder and return per-token logits"""
133
+ raise NotImplementedError
134
+
135
+ def rearrange_kv_cache(self, source_indices) -> None:
136
+ """Update the key-value cache according to the updated beams"""
137
+ raise NotImplementedError
138
+
139
+ def cleanup_caching(self) -> None:
140
+ """Clean up any resources or hooks after decoding is finished"""
141
+ pass
142
+
143
+
144
+ class PyTorchInference(Inference):
145
+ def __init__(self, model: "Whisper", initial_token_length: int):
146
+ self.model: "Whisper" = model
147
+ self.initial_token_length = initial_token_length
148
+ self.kv_cache = {}
149
+ self.hooks = []
150
+
151
+ key_modules = [block.attn.key for block in self.model.decoder.blocks]
152
+ value_modules = [block.attn.value for block in self.model.decoder.blocks]
153
+ self.kv_modules = key_modules + value_modules
154
+
155
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
156
+ if not self.kv_cache:
157
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
158
+
159
+ if tokens.shape[-1] > self.initial_token_length:
160
+ # only need to use the last token except in the first forward pass
161
+ tokens = tokens[:, -1:]
162
+
163
+ return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
164
+
165
+ def cleanup_caching(self):
166
+ for hook in self.hooks:
167
+ hook.remove()
168
+
169
+ self.kv_cache = {}
170
+ self.hooks = []
171
+
172
+ def rearrange_kv_cache(self, source_indices):
173
+ if source_indices != list(range(len(source_indices))):
174
+ for module in self.kv_modules:
175
+ # update the key/value cache to contain the selected sequences
176
+ self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
177
+
178
+
179
+ class SequenceRanker:
180
+ def rank(
181
+ self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
182
+ ) -> List[int]:
183
+ """
184
+ Given a list of groups of samples and their cumulative log probabilities,
185
+ return the indices of the samples in each group to select as the final result
186
+ """
187
+ raise NotImplementedError
188
+
189
+
190
+ class MaximumLikelihoodRanker(SequenceRanker):
191
+ """
192
+ Select the sample with the highest log probabilities, penalized using either
193
+ a simple length normalization or Google NMT paper's length penalty
194
+ """
195
+
196
+ def __init__(self, length_penalty: Optional[float]):
197
+ self.length_penalty = length_penalty
198
+
199
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
200
+ def scores(logprobs, lengths):
201
+ result = []
202
+ for logprob, length in zip(logprobs, lengths):
203
+ if self.length_penalty is None:
204
+ penalty = length
205
+ else:
206
+ # from the Google NMT paper
207
+ penalty = ((5 + length) / 6) ** self.length_penalty
208
+ result.append(logprob / penalty)
209
+ return result
210
+
211
+ # get the sequence with the highest score
212
+ lengths = [[len(t) for t in s] for s in tokens]
213
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
214
+
215
+
216
+ class TokenDecoder:
217
+ def reset(self):
218
+ """Initialize any stateful variables for decoding a new sequence"""
219
+
220
+ def update(
221
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
222
+ ) -> Tuple[Tensor, bool]:
223
+ """Specify how to select the next token, based on the current trace and logits
224
+
225
+ Parameters
226
+ ----------
227
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
228
+ all tokens in the context so far, including the prefix and sot_sequence tokens
229
+
230
+ logits : Tensor, shape = (n_batch, vocab_size)
231
+ per-token logits of the probability distribution at the current step
232
+
233
+ sum_logprobs : Tensor, shape = (n_batch)
234
+ cumulative log probabilities for each sequence
235
+
236
+ Returns
237
+ -------
238
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
239
+ the tokens, appended with the selected next token
240
+
241
+ completed : bool
242
+ True if all sequences have reached the end of text
243
+
244
+ """
245
+ raise NotImplementedError
246
+
247
+ def finalize(
248
+ self, tokens: Tensor, sum_logprobs: Tensor
249
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
250
+ """Finalize search and return the final candidate sequences
251
+
252
+ Parameters
253
+ ----------
254
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
255
+ all tokens in the context so far, including the prefix and sot_sequence
256
+
257
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
258
+ cumulative log probabilities for each sequence
259
+
260
+ Returns
261
+ -------
262
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
263
+ sequence of Tensors containing candidate token sequences, for each audio input
264
+
265
+ sum_logprobs : List[List[float]], length = n_audio
266
+ sequence of cumulative log probabilities corresponding to the above
267
+
268
+ """
269
+ raise NotImplementedError
270
+
271
+
272
+ class GreedyDecoder(TokenDecoder):
273
+ def __init__(self, temperature: float, eot: int):
274
+ self.temperature = temperature
275
+ self.eot = eot
276
+
277
+ def update(
278
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
279
+ ) -> Tuple[Tensor, bool]:
280
+ if self.temperature == 0:
281
+ next_tokens = logits.argmax(dim=-1)
282
+ else:
283
+ next_tokens = Categorical(logits=logits / self.temperature).sample()
284
+
285
+ logprobs = F.log_softmax(logits.float(), dim=-1)
286
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
287
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
288
+
289
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
290
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
291
+
292
+ completed = (tokens[:, -1] == self.eot).all()
293
+ return tokens, completed
294
+
295
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
296
+ # make sure each sequence has at least one EOT token at the end
297
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
298
+ return tokens, sum_logprobs.tolist()
299
+
300
+
301
+ class BeamSearchDecoder(TokenDecoder):
302
+ def __init__(
303
+ self,
304
+ beam_size: int,
305
+ eot: int,
306
+ inference: Inference,
307
+ patience: Optional[float] = None,
308
+ ):
309
+ self.beam_size = beam_size
310
+ self.eot = eot
311
+ self.inference = inference
312
+ self.patience = patience or 1.0
313
+ self.max_candidates: int = round(beam_size * self.patience)
314
+ self.finished_sequences = None
315
+
316
+ assert (
317
+ self.max_candidates > 0
318
+ ), f"Invalid beam size ({beam_size}) or patience ({patience})"
319
+
320
+ def reset(self):
321
+ self.finished_sequences = None
322
+
323
+ def update(
324
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
325
+ ) -> Tuple[Tensor, bool]:
326
+ if tokens.shape[0] % self.beam_size != 0:
327
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
328
+
329
+ n_audio = tokens.shape[0] // self.beam_size
330
+ if self.finished_sequences is None: # for the first update
331
+ self.finished_sequences = [{} for _ in range(n_audio)]
332
+
333
+ logprobs = F.log_softmax(logits.float(), dim=-1)
334
+ next_tokens, source_indices, finished_sequences = [], [], []
335
+ for i in range(n_audio):
336
+ scores, sources, finished = {}, {}, {}
337
+
338
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
339
+ for j in range(self.beam_size):
340
+ idx = i * self.beam_size + j
341
+ prefix = tokens[idx].tolist()
342
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
343
+ new_logprob = (sum_logprobs[idx] + logprob).item()
344
+ sequence = tuple(prefix + [token.item()])
345
+ scores[sequence] = new_logprob
346
+ sources[sequence] = idx
347
+
348
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
349
+ saved = 0
350
+ for sequence in sorted(scores, key=scores.get, reverse=True):
351
+ if sequence[-1] == self.eot:
352
+ finished[sequence] = scores[sequence]
353
+ else:
354
+ sum_logprobs[len(next_tokens)] = scores[sequence]
355
+ next_tokens.append(sequence)
356
+ source_indices.append(sources[sequence])
357
+
358
+ saved += 1
359
+ if saved == self.beam_size:
360
+ break
361
+
362
+ finished_sequences.append(finished)
363
+
364
+ tokens = torch.tensor(next_tokens, device=tokens.device)
365
+ self.inference.rearrange_kv_cache(source_indices)
366
+
367
+ # add newly finished sequences to self.finished_sequences
368
+ assert len(self.finished_sequences) == len(finished_sequences)
369
+ for previously_finished, newly_finished in zip(
370
+ self.finished_sequences, finished_sequences
371
+ ):
372
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
373
+ if len(previously_finished) >= self.max_candidates:
374
+ break # the candidate list is full
375
+ previously_finished[seq] = newly_finished[seq]
376
+
377
+ # mark as completed if all audio has enough number of samples
378
+ completed = all(
379
+ len(sequences) >= self.max_candidates
380
+ for sequences in self.finished_sequences
381
+ )
382
+ return tokens, completed
383
+
384
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
385
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
386
+ sum_logprobs = sum_logprobs.cpu()
387
+ for i, sequences in enumerate(self.finished_sequences):
388
+ if (
389
+ len(sequences) < self.beam_size
390
+ ): # when not enough sequences are finished
391
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
392
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
393
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
394
+ if len(sequences) >= self.beam_size:
395
+ break
396
+
397
+ tokens: List[List[Tensor]] = [
398
+ [torch.tensor(seq) for seq in sequences.keys()]
399
+ for sequences in self.finished_sequences
400
+ ]
401
+ sum_logprobs: List[List[float]] = [
402
+ list(sequences.values()) for sequences in self.finished_sequences
403
+ ]
404
+ return tokens, sum_logprobs
405
+
406
+
407
+ class LogitFilter:
408
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
409
+ """Apply any filtering or masking to logits in-place
410
+
411
+ Parameters
412
+ ----------
413
+ logits : Tensor, shape = (n_batch, vocab_size)
414
+ per-token logits of the probability distribution at the current step
415
+
416
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
417
+ all tokens in the context so far, including the prefix and sot_sequence tokens
418
+
419
+ """
420
+ raise NotImplementedError
421
+
422
+
423
+ class SuppressBlank(LogitFilter):
424
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
425
+ self.tokenizer = tokenizer
426
+ self.sample_begin = sample_begin
427
+
428
+ def apply(self, logits: Tensor, tokens: Tensor):
429
+ if tokens.shape[1] == self.sample_begin:
430
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
431
+
432
+
433
+ class SuppressTokens(LogitFilter):
434
+ def __init__(self, suppress_tokens: Sequence[int]):
435
+ self.suppress_tokens = list(suppress_tokens)
436
+
437
+ def apply(self, logits: Tensor, tokens: Tensor):
438
+ logits[:, self.suppress_tokens] = -np.inf
439
+
440
+
441
+ class ApplyTimestampRules(LogitFilter):
442
+ def __init__(
443
+ self,
444
+ tokenizer: Tokenizer,
445
+ sample_begin: int,
446
+ max_initial_timestamp_index: Optional[int],
447
+ ):
448
+ self.tokenizer = tokenizer
449
+ self.sample_begin = sample_begin
450
+ self.max_initial_timestamp_index = max_initial_timestamp_index
451
+
452
+ def apply(self, logits: Tensor, tokens: Tensor):
453
+ # suppress <|notimestamps|> which is handled by without_timestamps
454
+ if self.tokenizer.no_timestamps is not None:
455
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
456
+
457
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
458
+ for k in range(tokens.shape[0]):
459
+ sampled_tokens = tokens[k, self.sample_begin :]
460
+ seq = [t for t in sampled_tokens.tolist()]
461
+ last_was_timestamp = (
462
+ len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
463
+ )
464
+ penultimate_was_timestamp = (
465
+ len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
466
+ )
467
+
468
+ if last_was_timestamp:
469
+ if penultimate_was_timestamp: # has to be non-timestamp
470
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
471
+ else: # cannot be normal text tokens
472
+ logits[k, : self.tokenizer.eot] = -np.inf
473
+
474
+ timestamps = sampled_tokens[
475
+ sampled_tokens.ge(self.tokenizer.timestamp_begin)
476
+ ]
477
+ if timestamps.numel() > 0:
478
+ # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
479
+ # also force each segment to have a nonzero length, to prevent infinite looping
480
+ if last_was_timestamp and not penultimate_was_timestamp:
481
+ timestamp_last = timestamps[-1]
482
+ else:
483
+ timestamp_last = timestamps[-1] + 1
484
+ logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
485
+
486
+ if tokens.shape[1] == self.sample_begin:
487
+ # suppress generating non-timestamp tokens at the beginning
488
+ logits[:, : self.tokenizer.timestamp_begin] = -np.inf
489
+
490
+ # apply the `max_initial_timestamp` option
491
+ if self.max_initial_timestamp_index is not None:
492
+ last_allowed = (
493
+ self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
494
+ )
495
+ logits[:, last_allowed + 1 :] = -np.inf
496
+
497
+ # if sum of probability over timestamps is above any other token, sample timestamp
498
+ logprobs = F.log_softmax(logits.float(), dim=-1)
499
+ for k in range(tokens.shape[0]):
500
+ timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
501
+ dim=-1
502
+ )
503
+ max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
504
+ if timestamp_logprob > max_text_token_logprob:
505
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
506
+
507
+
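The pairing logic above is easiest to see with toy numbers. Assuming, purely for illustration, timestamp_begin = 10 and eot = 9: a sequence ending in a single (unpaired) timestamp token must be followed by another timestamp or EOT, so every ordinary text logit gets masked:

import numpy as np
import torch

timestamp_begin, eot = 10, 9                   # illustrative ids, not real tokenizer values
seq = [10, 3, 4, 12]                           # <t0> text text <t1>: last token is a timestamp
logits = torch.zeros(1, 20)
last_was_timestamp = seq[-1] >= timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= timestamp_begin
if last_was_timestamp and not penultimate_was_timestamp:
    logits[0, :eot] = -np.inf                  # text tokens forbidden; only EOT or a timestamp may follow
print(torch.isinf(logits[0, :eot]).all())      # tensor(True)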
508
+ class DecodingTask:
509
+ inference: Inference
510
+ sequence_ranker: SequenceRanker
511
+ decoder: TokenDecoder
512
+ logit_filters: List[LogitFilter]
513
+
514
+ def __init__(self, model: "Whisper", options: DecodingOptions):
515
+ self.model = model
516
+
517
+ language = options.language or "en"
518
+ tokenizer = get_tokenizer(
519
+ model.is_multilingual,
520
+ num_languages=model.num_languages,
521
+ language=language,
522
+ task=options.task,
523
+ )
524
+ self.tokenizer: Tokenizer = tokenizer
525
+ self.options: DecodingOptions = self._verify_options(options)
526
+
527
+ self.n_group: int = options.beam_size or options.best_of or 1
528
+ self.n_ctx: int = model.dims.n_text_ctx
529
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
530
+
531
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
532
+ if self.options.without_timestamps:
533
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
534
+
535
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
536
+ self.sample_begin: int = len(self.initial_tokens)
537
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
538
+
539
+ # inference: implements the forward pass through the decoder, including kv caching
540
+ self.inference = PyTorchInference(model, len(self.initial_tokens))
541
+
542
+ # sequence ranker: implements how to rank a group of sampled sequences
543
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
544
+
545
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
546
+ if options.beam_size is not None:
547
+ self.decoder = BeamSearchDecoder(
548
+ options.beam_size, tokenizer.eot, self.inference, options.patience
549
+ )
550
+ else:
551
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
552
+
553
+ # logit filters: applies various rules to suppress or penalize certain tokens
554
+ self.logit_filters = []
555
+ if self.options.suppress_blank:
556
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
557
+ if self.options.suppress_tokens:
558
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
559
+ if not options.without_timestamps:
560
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
561
+ max_initial_timestamp_index = None
562
+ if options.max_initial_timestamp:
563
+ max_initial_timestamp_index = round(
564
+ self.options.max_initial_timestamp / precision
565
+ )
566
+ self.logit_filters.append(
567
+ ApplyTimestampRules(
568
+ tokenizer, self.sample_begin, max_initial_timestamp_index
569
+ )
570
+ )
571
+
572
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
573
+ if options.beam_size is not None and options.best_of is not None:
574
+ raise ValueError("beam_size and best_of can't be given together")
575
+ if options.temperature == 0:
576
+ if options.best_of is not None:
577
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
578
+ if options.patience is not None and options.beam_size is None:
579
+ raise ValueError("patience requires beam_size to be given")
580
+ if options.length_penalty is not None and not (
581
+ 0 <= options.length_penalty <= 1
582
+ ):
583
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
584
+
585
+ return options
586
+
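For reference, the combinations rejected above, written against the DecodingOptions dataclass defined earlier in this file (the field values here are only examples):

ok = DecodingOptions(task="transcribe", language="en", beam_size=5, patience=1.0)
bad = DecodingOptions(beam_size=5, best_of=5)          # rejected: beam_size and best_of together
also_bad = DecodingOptions(best_of=3, temperature=0)   # rejected: best_of with greedy sampling
# DecodingTask(model, bad) would raise ValueError from _verify_options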
587
+ def _get_initial_tokens(self) -> Tuple[int]:
588
+ tokens = list(self.sot_sequence)
589
+
590
+ if prefix := self.options.prefix:
591
+ prefix_tokens = (
592
+ self.tokenizer.encode(" " + prefix.strip())
593
+ if isinstance(prefix, str)
594
+ else prefix
595
+ )
596
+ if self.sample_len is not None:
597
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
598
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
599
+ tokens = tokens + prefix_tokens
600
+
601
+ if prompt := self.options.prompt:
602
+ prompt_tokens = (
603
+ self.tokenizer.encode(" " + prompt.strip())
604
+ if isinstance(prompt, str)
605
+ else prompt
606
+ )
607
+ tokens = (
608
+ [self.tokenizer.sot_prev]
609
+ + prompt_tokens[-(self.n_ctx // 2 - 1) :]
610
+ + tokens
611
+ )
612
+
613
+ return tuple(tokens)
614
+
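The resulting context layout is [sot_prev] + prompt_tokens + sot_sequence + prefix_tokens, with prompt and prefix each truncated to fit the text context. A sketch with illustrative token ids:

sot_prev = 50361                               # illustrative ids only
sot_sequence = (50258, 50259, 50359)           # <|startoftranscript|><|en|><|transcribe|>
prompt_tokens, prefix_tokens = [101, 102], [201, 202]
initial = [sot_prev] + prompt_tokens + list(sot_sequence) + prefix_tokens
print(tuple(initial))
# (50361, 101, 102, 50258, 50259, 50359, 201, 202)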
615
+ def _get_suppress_tokens(self) -> Tuple[int]:
616
+ suppress_tokens = self.options.suppress_tokens
617
+
618
+ if isinstance(suppress_tokens, str):
619
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
620
+
621
+ if -1 in suppress_tokens:
622
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
623
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
624
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
625
+ suppress_tokens = [] # interpret empty string as an empty list
626
+ else:
627
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
628
+
629
+ suppress_tokens.extend(
630
+ [
631
+ self.tokenizer.transcribe,
632
+ self.tokenizer.translate,
633
+ self.tokenizer.sot,
634
+ self.tokenizer.sot_prev,
635
+ self.tokenizer.sot_lm,
636
+ ]
637
+ )
638
+ if self.tokenizer.no_speech is not None:
639
+ # no-speech probability is collected separately
640
+ suppress_tokens.append(self.tokenizer.no_speech)
641
+
642
+ return tuple(sorted(set(suppress_tokens)))
643
+
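In practice the default suppress_tokens value of "-1" expands to the tokenizer's non_speech_tokens plus the special tokens appended above. A small sketch of the same expansion with stand-in token ids:

non_speech_tokens = [11, 13, 30]                    # stand-in for tokenizer.non_speech_tokens
specials = [50358, 50359, 50258, 50361, 50360]      # translate, transcribe, sot, sot_prev, sot_lm (illustrative ids)

raw = "-1,42"                                       # same comma-separated format accepted by DecodingOptions
toks = [int(t) for t in raw.split(",")]
if -1 in toks:
    toks = [t for t in toks if t >= 0] + non_speech_tokens
print(tuple(sorted(set(toks + specials))))
# (11, 13, 30, 42, 50258, 50358, 50359, 50360, 50361)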
644
+ def _get_audio_features(self, mel: Tensor):
645
+ if self.options.fp16:
646
+ mel = mel.half()
647
+
648
+ if mel.shape[-2:] == (
649
+ self.model.dims.n_audio_ctx,
650
+ self.model.dims.n_audio_state,
651
+ ):
652
+ # encoded audio features are given; skip audio encoding
653
+ audio_features = mel
654
+ else:
655
+ audio_features = self.model.encoder(mel)
656
+
657
+ if audio_features.dtype != (
658
+ torch.float16 if self.options.fp16 else torch.float32
659
+ ):
660
+ raise TypeError(
661
+ f"audio_features has an incorrect dtype: {audio_features.dtype}"
662
+ )
663
+
664
+ return audio_features
665
+
666
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
667
+ languages = [self.options.language] * audio_features.shape[0]
668
+ lang_probs = None
669
+
670
+ if self.options.language is None or self.options.task == "lang_id":
671
+ lang_tokens, lang_probs = self.model.detect_language(
672
+ audio_features, self.tokenizer
673
+ )
674
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
675
+ if self.options.language is None:
676
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
677
+
678
+ return languages, lang_probs
679
+
680
+ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
681
+ n_batch = tokens.shape[0]
682
+ sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
683
+ no_speech_probs = [np.nan] * n_batch
684
+
685
+ try:
686
+ for i in range(self.sample_len):
687
+ logits = self.inference.logits(tokens, audio_features)
688
+
689
+ if (
690
+ i == 0 and self.tokenizer.no_speech is not None
691
+ ): # save no_speech_probs
692
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
693
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
694
+
695
+ # now we need to consider the logits at the last token only
696
+ logits = logits[:, -1]
697
+
698
+ # apply the logit filters, e.g. to suppress or penalize certain tokens
699
+ for logit_filter in self.logit_filters:
700
+ logit_filter.apply(logits, tokens)
701
+
702
+ # expand the tokens tensor with the selected next tokens
703
+ tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
704
+
705
+ if completed or tokens.shape[-1] > self.n_ctx:
706
+ break
707
+ finally:
708
+ self.inference.cleanup_caching()
709
+
710
+ return tokens, sum_logprobs, no_speech_probs
711
+
712
+ @torch.no_grad()
713
+ def run(self, mel: Tensor) -> List[DecodingResult]:
714
+ self.decoder.reset()
715
+ tokenizer: Tokenizer = self.tokenizer
716
+ n_audio: int = mel.shape[0]
717
+
718
+ audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
719
+ tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
720
+
721
+ # detect language if requested, overwriting the language token
722
+ languages, language_probs = self._detect_language(audio_features, tokens)
723
+ if self.options.task == "lang_id":
724
+ return [
725
+ DecodingResult(
726
+ audio_features=features, language=language, language_probs=probs
727
+ )
728
+ for features, language, probs in zip(
729
+ audio_features, languages, language_probs
730
+ )
731
+ ]
732
+
733
+ # repeat text tensors by the group size, for beam search or best-of-n sampling
734
+ tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
735
+
736
+ # call the main sampling loop
737
+ tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
738
+
739
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
740
+ audio_features = audio_features[:: self.n_group]
741
+ no_speech_probs = no_speech_probs[:: self.n_group]
742
+ assert audio_features.shape[0] == len(no_speech_probs) == n_audio
743
+
744
+ tokens = tokens.reshape(n_audio, self.n_group, -1)
745
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
746
+
747
+ # get the final candidates for each group, and slice between the first sampled token and EOT
748
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
749
+ tokens: List[List[Tensor]] = [
750
+ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
751
+ for s in tokens
752
+ ]
753
+
754
+ # select the top-ranked sample in each group
755
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
756
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
757
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
758
+
759
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
760
+ avg_logprobs: List[float] = [
761
+ lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
762
+ ]
763
+
764
+ fields = (
765
+ texts,
766
+ languages,
767
+ tokens,
768
+ audio_features,
769
+ avg_logprobs,
770
+ no_speech_probs,
771
+ )
772
+ if len(set(map(len, fields))) != 1:
773
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
774
+
775
+ return [
776
+ DecodingResult(
777
+ audio_features=features,
778
+ language=language,
779
+ tokens=tokens,
780
+ text=text,
781
+ avg_logprob=avg_logprob,
782
+ no_speech_prob=no_speech_prob,
783
+ temperature=self.options.temperature,
784
+ compression_ratio=compression_ratio(text),
785
+ )
786
+ for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
787
+ *fields
788
+ )
789
+ ]
790
+
791
+
792
+ @torch.no_grad()
793
+ def decode(
794
+ model: "Whisper",
795
+ mel: Tensor,
796
+ options: DecodingOptions = DecodingOptions(),
797
+ **kwargs,
798
+ ) -> Union[DecodingResult, List[DecodingResult]]:
799
+ """
800
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
801
+
802
+ Parameters
803
+ ----------
804
+ model: Whisper
805
+ the Whisper model instance
806
+
807
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
808
+ A tensor containing the Mel spectrogram(s)
809
+
810
+ options: DecodingOptions
811
+ A dataclass that contains all necessary options for decoding 30-second segments
812
+
813
+ Returns
814
+ -------
815
+ result: Union[DecodingResult, List[DecodingResult]]
816
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
817
+ """
818
+ if single := mel.ndim == 2:
819
+ mel = mel.unsqueeze(0)
820
+
821
+ if kwargs:
822
+ options = replace(options, **kwargs)
823
+
824
+ result = DecodingTask(model, options).run(mel)
825
+
826
+ return result[0] if single else result
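A hedged usage sketch: the model construction and the mel-spectrogram helper are assumed to come from elsewhere in this package; only the call shape and the DecodingResult fields matter here.

# mel: a (80, 3000) log-Mel spectrogram of one 30-second segment, already on model.device
options = DecodingOptions(language="en", beam_size=5, without_timestamps=True)
result = decode(model, mel, options)
print(result.text, result.avg_logprob, result.no_speech_prob)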
egogpt/model/speech_encoder/model.py ADDED
@@ -0,0 +1,345 @@
1
+ import base64
2
+ import gzip
3
+ from contextlib import contextmanager
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Iterable, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import Tensor, nn
11
+
12
+ from .decoding import decode as decode_function
13
+ from .decoding import detect_language as detect_language_function
14
+ from .transcribe import transcribe as transcribe_function
15
+
16
+ try:
17
+ from torch.nn.functional import scaled_dot_product_attention
18
+
19
+ SDPA_AVAILABLE = True
20
+ except (ImportError, RuntimeError, OSError):
21
+ scaled_dot_product_attention = None
22
+ SDPA_AVAILABLE = False
23
+
24
+
25
+ @dataclass
26
+ class ModelDimensions:
27
+ n_mels: int
28
+ n_audio_ctx: int
29
+ n_audio_state: int
30
+ n_audio_head: int
31
+ n_audio_layer: int
32
+ n_vocab: int
33
+ n_text_ctx: int
34
+ n_text_state: int
35
+ n_text_head: int
36
+ n_text_layer: int
37
+
38
+
39
+ class LayerNorm(nn.LayerNorm):
40
+ def forward(self, x: Tensor) -> Tensor:
41
+ return super().forward(x).type(x.dtype) # Choiszt fix
42
+
43
+
44
+ class Linear(nn.Linear):
45
+ def forward(self, x: Tensor) -> Tensor:
46
+ return F.linear(
47
+ x,
48
+ self.weight.to(x.dtype),
49
+ None if self.bias is None else self.bias.to(x.dtype),
50
+ )
51
+
52
+
53
+ class Conv1d(nn.Conv1d):
54
+ def _conv_forward(
55
+ self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
56
+ ) -> Tensor:
57
+ return super()._conv_forward(
58
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
59
+ )
60
+
61
+
62
+ def sinusoids(length, channels, max_timescale=10000):
63
+ """Returns sinusoids for positional embedding"""
64
+ assert channels % 2 == 0
65
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
66
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
67
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
68
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
69
+
70
+
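For example, the positional table registered by AudioEncoder below has one row per audio frame, with the sine half followed by the cosine half:

pe = sinusoids(length=1500, channels=512)
print(pe.shape)                              # torch.Size([1500, 512])
print(pe[0, 0].item(), pe[0, 256].item())    # 0.0 1.0 (sin/cos halves at t=0)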
71
+ @contextmanager
72
+ def disable_sdpa():
73
+ prev_state = MultiHeadAttention.use_sdpa
74
+ try:
75
+ MultiHeadAttention.use_sdpa = False
76
+ yield
77
+ finally:
78
+ MultiHeadAttention.use_sdpa = prev_state
79
+
80
+
81
+ class MultiHeadAttention(nn.Module):
82
+ use_sdpa = True
83
+
84
+ def __init__(self, n_state: int, n_head: int):
85
+ super().__init__()
86
+ self.n_head = n_head
87
+ self.query = Linear(n_state, n_state)
88
+ self.key = Linear(n_state, n_state, bias=False)
89
+ self.value = Linear(n_state, n_state)
90
+ self.out = Linear(n_state, n_state)
91
+
92
+ def forward(
93
+ self,
94
+ x: Tensor,
95
+ xa: Optional[Tensor] = None,
96
+ mask: Optional[Tensor] = None,
97
+ kv_cache: Optional[dict] = None,
98
+ ):
99
+ q = self.query(x)
100
+
101
+ if kv_cache is None or xa is None or self.key not in kv_cache:
102
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
103
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
104
+ k = self.key(x if xa is None else xa)
105
+ v = self.value(x if xa is None else xa)
106
+ else:
107
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
108
+ k = kv_cache[self.key]
109
+ v = kv_cache[self.value]
110
+
111
+ wv, qk = self.qkv_attention(q, k, v, mask)
112
+ return self.out(wv), qk
113
+
114
+ def qkv_attention(
115
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
116
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
117
+ n_batch, n_ctx, n_state = q.shape
118
+ scale = (n_state // self.n_head) ** -0.25
119
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
120
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
121
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
122
+
123
+ if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
124
+ a = scaled_dot_product_attention(
125
+ q, k, v, is_causal=mask is not None and n_ctx > 1
126
+ )
127
+ out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
128
+ qk = None
129
+ else:
130
+ qk = (q * scale) @ (k * scale).transpose(-1, -2)
131
+ if mask is not None:
132
+ qk = qk + mask[:n_ctx, :n_ctx]
133
+ qk = qk.float()
134
+
135
+ w = F.softmax(qk, dim=-1).to(q.dtype)
136
+ out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
137
+ qk = qk.detach()
138
+
139
+ return out, qk
140
+
141
+
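A quick way to inspect the raw attention weights is to bypass SDPA via the context manager above; the manual path then returns the qk matrix instead of None (toy sizes, single self-attention block):

mha = MultiHeadAttention(n_state=64, n_head=4)
x = torch.randn(2, 10, 64)
with disable_sdpa():            # forces the explicit softmax path in qkv_attention
    out, qk = mha(x)
print(out.shape, qk.shape)      # torch.Size([2, 10, 64]) torch.Size([2, 4, 10, 10])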
142
+ class ResidualAttentionBlock(nn.Module):
143
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
144
+ super().__init__()
145
+
146
+ self.attn = MultiHeadAttention(n_state, n_head)
147
+ self.attn_ln = LayerNorm(n_state)
148
+
149
+ self.cross_attn = (
150
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
151
+ )
152
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
153
+
154
+ n_mlp = n_state * 4
155
+ self.mlp = nn.Sequential(
156
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
157
+ )
158
+ self.mlp_ln = LayerNorm(n_state)
159
+
160
+ def forward(
161
+ self,
162
+ x: Tensor,
163
+ xa: Optional[Tensor] = None,
164
+ mask: Optional[Tensor] = None,
165
+ kv_cache: Optional[dict] = None,
166
+ ):
167
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
168
+ if self.cross_attn:
169
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
170
+ x = x + self.mlp(self.mlp_ln(x))
171
+ return x
172
+
173
+
174
+ class AudioEncoder(nn.Module):
175
+ def __init__(
176
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
177
+ ):
178
+ super().__init__()
179
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
180
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
181
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
182
+
183
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
184
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
185
+ )
186
+ self.ln_post = LayerNorm(n_state)
187
+
188
+ def forward(self, x: Tensor):
189
+ """
190
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
191
+ the mel spectrogram of the audio
192
+ """
193
+ x = F.gelu(self.conv1(x))
194
+ x = F.gelu(self.conv2(x))
195
+ x = x.permute(0, 2, 1)
196
+
197
+ assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
198
+ x = (x + self.positional_embedding).to(x.dtype)
199
+
200
+ for block in self.blocks:
201
+ x = block(x)
202
+
203
+ x = self.ln_post(x)
204
+ return x
205
+
206
+
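A smoke test with small, made-up dimensions (not a released Whisper configuration); the stride-2 convolution halves the 3000 mel frames down to n_ctx = 1500:

enc = AudioEncoder(n_mels=80, n_ctx=1500, n_state=384, n_head=6, n_layer=2)
feats = enc(torch.randn(1, 80, 3000))
print(feats.shape)              # torch.Size([1, 1500, 384])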
207
+ class TextDecoder(nn.Module):
208
+ def __init__(
209
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
210
+ ):
211
+ super().__init__()
212
+
213
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
214
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
215
+
216
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
217
+ [
218
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
219
+ for _ in range(n_layer)
220
+ ]
221
+ )
222
+ self.ln = LayerNorm(n_state)
223
+
224
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
225
+ self.register_buffer("mask", mask, persistent=False)
226
+
227
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
228
+ """
229
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
230
+ the text tokens
231
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
232
+ the encoded audio features to be attended on
233
+ """
234
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
235
+ x = (
236
+ self.token_embedding(x)
237
+ + self.positional_embedding[offset : offset + x.shape[-1]]
238
+ )
239
+ x = x.to(xa.dtype)
240
+
241
+ for block in self.blocks:
242
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
243
+
244
+ x = self.ln(x)
245
+ logits = (
246
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
247
+ ).float()
248
+
249
+ return logits
250
+
251
+
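Paired with encoded audio features, one decoding step produces per-token logits over the vocabulary. Dimensions below are illustrative; weights are untrained and the positional table is uninitialized, so only the output shape is meaningful:

dec = TextDecoder(n_vocab=1000, n_ctx=448, n_state=384, n_head=6, n_layer=2)
xa = torch.randn(1, 1500, 384)                  # stand-in for AudioEncoder output
logits = dec(torch.tensor([[7]]), xa)           # a single text token attending over the audio
print(logits.shape)                             # torch.Size([1, 1, 1000])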
252
+ class Whisper(nn.Module):
253
+ def __init__(self, dims: ModelDimensions):
254
+ super().__init__()
255
+ self.dims = dims
256
+ self.encoder = AudioEncoder(
257
+ self.dims.n_mels,
258
+ self.dims.n_audio_ctx,
259
+ self.dims.n_audio_state,
260
+ self.dims.n_audio_head,
261
+ self.dims.n_audio_layer,
262
+ )
263
+ self.decoder = TextDecoder(
264
+ self.dims.n_vocab,
265
+ self.dims.n_text_ctx,
266
+ self.dims.n_text_state,
267
+ self.dims.n_text_head,
268
+ self.dims.n_text_layer,
269
+ )
270
+ # use the last half among the decoder layers for time alignment by default;
271
+ # to use a specific set of heads, see `set_alignment_heads()` below.
272
+ all_heads = torch.zeros(
273
+ self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
274
+ )
275
+ all_heads[self.dims.n_text_layer // 2 :] = True
276
+ self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
277
+
278
+ def set_alignment_heads(self, dump: bytes):
279
+ array = np.frombuffer(
280
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
281
+ ).copy()
282
+ mask = torch.from_numpy(array).reshape(
283
+ self.dims.n_text_layer, self.dims.n_text_head
284
+ )
285
+ self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
286
+
287
+ def embed_audio(self, mel: torch.Tensor):
288
+ return self.encoder(mel)
289
+
290
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
291
+ return self.decoder(tokens, audio_features)
292
+
293
+ def forward(
294
+ self, mel: torch.Tensor, tokens: torch.Tensor
295
+ ) -> Dict[str, torch.Tensor]:
296
+ return self.decoder(tokens, self.encoder(mel))
297
+
298
+ @property
299
+ def device(self):
300
+ return next(self.parameters()).device
301
+
302
+ @property
303
+ def is_multilingual(self):
304
+ return self.dims.n_vocab >= 51865
305
+
306
+ @property
307
+ def num_languages(self):
308
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)
309
+
310
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
311
+ """
312
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
313
+ tensors calculated for the previous positions. This method returns a dictionary that stores
314
+ all caches, and the necessary hooks for the key and value projection modules that save the
315
+ intermediate tensors to be reused during later calculations.
316
+
317
+ Returns
318
+ -------
319
+ cache : Dict[nn.Module, torch.Tensor]
320
+ A dictionary object mapping the key/value projection modules to its cache
321
+ hooks : List[RemovableHandle]
322
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
323
+ """
324
+ cache = {**cache} if cache is not None else {}
325
+ hooks = []
326
+
327
+ def save_to_cache(module, _, output):
328
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
329
+ # save as-is, for the first token or cross attention
330
+ cache[module] = output
331
+ else:
332
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
333
+ return cache[module]
334
+
335
+ def install_hooks(layer: nn.Module):
336
+ if isinstance(layer, MultiHeadAttention):
337
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
338
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
339
+
340
+ self.decoder.apply(install_hooks)
341
+ return cache, hooks
342
+
343
+ detect_language = detect_language_function
344
+ transcribe = transcribe_function
345
+ decode = decode_function
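Incremental decoding wires these pieces together roughly as follows, mirroring how PyTorchInference in decoding.py uses the cache; tokens and mel are assumed tensors of the shapes documented above:

cache, hooks = model.install_kv_cache_hooks()
audio_features = model.embed_audio(mel)
logits = model.decoder(tokens, audio_features, kv_cache=cache)                # keys/values are now cached
next_logits = model.decoder(tokens[:, -1:], audio_features, kv_cache=cache)   # reuse the cache for the next token
for hook in hooks:
    hook.remove()                                                             # detach the hooks when finished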