aisyahhrazak committed
Commit af53a9f
1 Parent(s): bb4e038

Upload MM_LLMs

Files changed (3):
  1. config.json +6 -2
  2. model.safetensors +2 -2
  3. modeling.py +385 -0
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "multimodal-tinyllama-whisper-small-siglip/checkpoint-100",
+  "_name_or_path": "multimodal-tinyllama-whisper-small-siglip/checkpoint-450",
   "architectures": [
     "MM_LLMs"
   ],
@@ -206,6 +206,10 @@
   },
   "audio_conv_kernel": 240,
   "audio_conv_stride": 220,
+  "auto_map": {
+    "AutoConfig": "modeling.MM_LLMs_Config",
+    "AutoModel": "modeling.MM_LLMs"
+  },
   "hidden_size": 2048,
   "image_config": {
     "_name_or_path": "google/siglip-base-patch16-224",
@@ -503,6 +507,6 @@
   },
   "model_type": "mm_llms",
   "n_frames": 6,
-  "torch_dtype": "float16",
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.37.1"
 }
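
The new "auto_map" entries register the custom classes from modeling.py with the transformers Auto API, and the stored dtype moves from float16 to bfloat16. A minimal loading sketch; "<repo-or-local-path>" is a placeholder for this repository's id or a local clone and is not taken from the diff:

    import torch
    from transformers import AutoConfig, AutoModel

    # trust_remote_code lets AutoConfig/AutoModel follow auto_map into modeling.py
    config = AutoConfig.from_pretrained("<repo-or-local-path>", trust_remote_code=True)
    model = AutoModel.from_pretrained(
        "<repo-or-local-path>",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,  # matches the updated "torch_dtype" in config.json
    )
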
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f2fcfbcd7d6fb5a4f2750f5dc80b6540c3ed664c95587aaeb3113b260d766062
-size 3509161510
+oid sha256:e2d41892dc7281cb5a6667ac3d6f59cd55dc4bdb2b62832347fe305ae63a40af
+size 3509162622
modeling.py ADDED
@@ -0,0 +1,385 @@
+from collections import Counter
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+from transformers.modeling_utils import PreTrainedModel, PretrainedConfig
+from transformers import CONFIG_MAPPING
+from transformers.modeling_outputs import BaseModelOutput
+from transformers import GenerationConfig
+from transformers import CLIPConfig, CLIPProcessor, CLIPModel, AutoModel
+from transformers import WhisperConfig, WhisperPreTrainedModel, WhisperModel
+from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig
+
+
+def most_frequent_element(tensor):
+    flattened_list = tensor.flatten().tolist()
+    counter = Counter(flattened_list)
+    most_common_element = counter.most_common(1)[0][1]
+
+    return most_common_element
+
+
+class MM_LLMs_Config(PretrainedConfig):
+    model_type = 'mm_llms'
+    is_composition = True
+
+    def __init__(self, attention_heads=8, image_conv_kernel=48, image_conv_stride=36,
+                 audio_conv_kernel=240, audio_conv_stride=220,
+                 image_config=None, audio_config=None, llm_config=None, **kwargs):
+
+        self.image_config = image_config
+        self.audio_config = audio_config
+        self.llm_config = llm_config
+        self.attention_heads = attention_heads
+        self.image_conv_kernel = image_conv_kernel
+        self.image_conv_stride = image_conv_stride
+        self.audio_conv_kernel = audio_conv_kernel
+        self.audio_conv_stride = audio_conv_stride
+
+        if isinstance(self.image_config, dict):
+            image_config["model_type"] = (
+                image_config["model_type"] if "model_type" in image_config else "clip"
+            )
+            self.image_config = CONFIG_MAPPING[image_config["model_type"]](**image_config)
+        if isinstance(self.audio_config, dict):
+            audio_config["model_type"] = (
+                audio_config["model_type"] if "model_type" in audio_config else "whisper"
+            )
+            self.audio_config = CONFIG_MAPPING[audio_config["model_type"]](**audio_config)
+        if isinstance(self.llm_config, dict):
+            llm_config["model_type"] = llm_config["model_type"] if "model_type" in llm_config else "llama"
+            self.llm_config = CONFIG_MAPPING[llm_config["model_type"]](**llm_config)
+
+        self.hidden_size = max(
+            self.llm_config.hidden_size,
+            self.image_config.vision_config.hidden_size,
+            self.audio_config.d_model,
+        )
+
+        super().__init__(**kwargs)
+
+
+class MM_LLMs(PreTrainedModel):
+    config_class = MM_LLMs_Config
+    supports_gradient_checkpointing = True
+    _supports_flash_attn_2 = True
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        self.image_encoder = AutoModel.from_config(config.image_config)
+
+        self.audio_encoder = AutoModel.from_config(config.audio_config)
+
+        self.llm = AutoModelForCausalLM.from_config(config.llm_config)
+
+        attn_dropout = 0.1
+        is_add_bias_kv = True
+        is_add_zero_attn = True
+        self.temporal_self_attention = nn.MultiheadAttention(
+            config.image_config.text_config.hidden_size,
+            config.attention_heads,
+            dropout=attn_dropout,
+            add_bias_kv=is_add_bias_kv,
+            add_zero_attn=is_add_zero_attn)
+
+        self.audio_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size,
+                                                           config.attention_heads * 2,
+                                                           dropout=attn_dropout,
+                                                           add_bias_kv=is_add_bias_kv,
+                                                           add_zero_attn=is_add_zero_attn)
+
+        self.image_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size,
+                                                           config.attention_heads * 2,
+                                                           dropout=attn_dropout,
+                                                           add_bias_kv=is_add_bias_kv,
+                                                           add_zero_attn=is_add_zero_attn)
+
+        self.transform_audio_to_hidden = nn.Linear(config.audio_config.d_model,
+                                                   config.llm_config.hidden_size)
+        self.transform_image_to_hidden = nn.Linear(config.image_config.text_config.hidden_size,
+                                                   config.llm_config.hidden_size)
+
+        self.project_image = nn.Conv1d(
+            config.image_config.text_config.hidden_size,
+            config.image_config.text_config.hidden_size,
+            kernel_size=config.image_conv_kernel,
+            stride=config.image_conv_stride)
+        self.project_audio = nn.Conv1d(
+            config.audio_config.d_model,
+            config.audio_config.d_model,
+            kernel_size=config.audio_conv_kernel,
+            stride=config.audio_conv_stride)
+
+        self.visual_projection = nn.Linear(
+            self.image_encoder.vision_model.config.hidden_size,
+            self.config.image_config.text_config.hidden_size,
+            bias=False)
+
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.layer_norm = nn.LayerNorm(config.image_config.text_config.hidden_size)
+        self.softmax = nn.Softmax(dim=-1)
+
+        self.sigmoid = nn.Sigmoid()
+
+        self.loss_fct = CrossEntropyLoss()
+
+        self.init_weights()
+
+    def forward(self,
+                input_ids: torch.LongTensor = None,
+                image_index: torch.LongTensor = None,
+                audio_index: torch.LongTensor = None,
+                image_starts: torch.int = None,
+                image_ends: torch.int = None,
+                audio_starts: torch.int = None,
+                audio_ends: torch.int = None,
+                images: torch.FloatTensor = None,
+                audios: torch.FloatTensor = None,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                inputs_embeds: Optional[torch.FloatTensor] = None,
+                labels: Optional[torch.LongTensor] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                use_cache: Optional[bool] = None,
+                return_dict: Optional[bool] = None):
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        images = images.type(self.image_encoder.dtype) if images is not None else None
+        audios = audios.type(self.audio_encoder.dtype) if audios is not None else None
+
+        model_inputs = self.prepare_inputs_for_generation(
+            input_ids=input_ids,
+            image_index=image_index,
+            audio_index=audio_index,
+            image_starts=image_starts,
+            image_ends=image_ends,
+            audio_starts=audio_starts,
+            audio_ends=audio_ends,
+            images=images,
+            audios=audios,
+            attention_mask=attention_mask,
+            labels=labels)
+
+        outputs = self.llm(
+            inputs_embeds=model_inputs['inputs_embeds'],
+            attention_mask=model_inputs['attention_mask'],
+            labels=model_inputs['labels'],
+            return_dict=return_dict)
+
+        return outputs
+
+    def prepare_inputs_for_generation(
+            self,
+            input_ids,
+            past_key_values=None,
+            inputs_embeds=None,
+            images=None,
+            audios=None,
+            audio_starts=None,
+            audio_ends=None,
+            image_starts=None,
+            image_ends=None,
+            attention_mask=None,
+            labels=None,
+            audio_index=None,
+            image_index=None,
+            **kwargs):
+
+        image_features = self.encode_image(
+            images) if images is not None else None
+        audio_features = self.encode_audio(
+            audios) if audios is not None else None
+        embed_tokens = self.llm.model.embed_tokens
+        text_embeddings = embed_tokens(input_ids)
+
+        token_embeddings = embed_tokens.weight.unsqueeze(0).repeat(
+            text_embeddings.size(0), 1, 1).transpose(0, 1)
+
+        ingore_num = 0
+
+        if audio_features is not None:
+
+            audio_starts = embed_tokens(audio_starts).unsqueeze(1)
+            audio_ends = embed_tokens(audio_ends).unsqueeze(1)
+
+            audio_features = self.project_audio(
+                audio_features.transpose(
+                    1, 2).contiguous()).transpose(
+                1, 2).contiguous()
+
+            audio_features = self.transform_audio_to_hidden(audio_features)
+
+            max_count = most_frequent_element(audio_index)
+
+            seq_img = audio_features.shape[1]
+            dim = token_embeddings.shape[2]
+
+            new_audio = torch.zeros(
+                (token_embeddings.shape[1],
+                 seq_img * max_count,
+                 dim),
+                device=token_embeddings.device,
+                dtype=token_embeddings.dtype)
+            current_dim = 0
+            for no, index in enumerate(audio_index):
+                if no > 0 and audio_index[no - 1] == index:
+                    current_dim += 1
+                else:
+                    current_dim = 0
+                new_audio[index, current_dim *
+                          seq_img: (current_dim + 1) * seq_img] = audio_features[no]
+            last_index = audio_index[0]
+
+            audio_features = self.audio_align_attention(
+                new_audio.transpose(
+                    0,
+                    1),
+                token_embeddings,
+                token_embeddings)[0].transpose(
+                0,
+                1).contiguous()
+
+            # audio_features = add_positional_encoding(audio_features)
+
+            audio_inputs = torch.cat(
+                [torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)
+
+            text_embeddings = torch.cat(
+                [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
+                dim=1)
+
+            ingore_num += (audio_inputs.size(1))
+
+        if image_features is not None:
+
+            image_starts = embed_tokens(image_starts).unsqueeze(1)
+            image_ends = embed_tokens(image_ends).unsqueeze(1)
+
+            image_features = self.project_image(
+                image_features.transpose(
+                    1, 2).contiguous()).transpose(
+                1, 2).contiguous()
+
+            image_features = self.transform_image_to_hidden(image_features)
+
+            max_count = most_frequent_element(image_index)
+
+            seq_img = image_features.shape[1]
+            dim = token_embeddings.shape[2]
+
+            new_img = torch.zeros(
+                (token_embeddings.shape[1],
+                 seq_img * max_count,
+                 dim),
+                device=token_embeddings.device,
+                dtype=token_embeddings.dtype)
+
+            current_dim = 0
+            for no, index in enumerate(image_index):
+                if no > 0 and image_index[no - 1] == index:
+                    current_dim += 1
+                else:
+                    current_dim = 0
+                new_img[index, current_dim *
+                        seq_img: (current_dim + 1) * seq_img] = image_features[no]
+            last_index = image_index[0]
+
+            image_features = self.image_align_attention(
+                new_img.transpose(
+                    0,
+                    1),
+                token_embeddings,
+                token_embeddings)[0].transpose(
+                0,
+                1).contiguous()
+
+            # image_features = add_positional_encoding(image_features)
+
+            image_inputs = torch.cat(
+                [torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)
+
+            text_embeddings = torch.cat(
+                [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1),
+                 text_embeddings[:, 1:, :]], dim=1)
+
+            ingore_num += (image_inputs.size(1))
+
+        if attention_mask is not None:
+            attentionmask = torch.tensor([1]*ingore_num*text_embeddings.size(0),
+                                         device=text_embeddings.device).view(text_embeddings.size(0), -1)
+            attentionmask = torch.cat([attentionmask, attention_mask], dim=1)
+        else:
+            attentionmask = None
+
+        if labels is not None:
+            labels_ = torch.tensor([-100]*ingore_num*text_embeddings.size(0),
+                                   device=text_embeddings.device).view(text_embeddings.size(0), -1)
+            labels = torch.cat([labels_, labels], dim=1)
+        else:
+            labels = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "inputs_embeds": text_embeddings,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attentionmask,
+                "labels": labels,
+            }
+        )
+        return model_inputs
+
+    def encode_audio(self, audios):
+        audio_features = self.audio_encoder.encoder(audios)
+        return audio_features[0]
+
+    def encode_image(self, images):
+
+        image_features = self.visual_projection(
+            self.image_encoder.vision_model(images)[0])[:, 1:, :]
+
+        return image_features
+
+
+def create_positional_encoding(L, h):
+    # Create a tensor to store the position encoding
+    position_encoding = torch.zeros(L, h)
+
+    # Fill the position encoding tensor
+    for pos in range(L):
+        for i in range(0, h, 2):
+            div_term = torch.exp(torch.tensor(-(math.log(10000.0) / h * (2 * i))))
+            position_encoding[pos, i] = torch.sin(pos * div_term)
+            position_encoding[pos, i + 1] = torch.cos(pos * div_term)
+
+    return position_encoding
+
+
+def add_positional_encoding(tensor):
+    N, L, h = tensor.size()  # batch size, sequence length, and feature dimension
+
+    # Create position embedding tensor
+    position_embedding = create_positional_encoding(L, h).to(tensor.device).to(tensor.dtype)
+
+    # Expand position embedding to match input tensor dimensions
+    position_embedding = position_embedding.unsqueeze(0).expand(N, -1, -1)
+
+    # Add position embedding to the input tensor
+    return tensor + position_embedding
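
For orientation, a sketch of how the composite configuration defined above could be assembled from its three sub-model configurations and used to instantiate MM_LLMs. Only the SigLIP id is confirmed by config.json; the Whisper and TinyLlama ids are assumptions inferred from the checkpoint name, and the model below is randomly initialised until the weights in model.safetensors are loaded:

    from transformers import AutoConfig
    from modeling import MM_LLMs, MM_LLMs_Config  # the module added in this commit (run from a clone of the repo)

    # Sub-model configs; only the SigLIP id appears in config.json,
    # the other two are assumptions based on the checkpoint name.
    image_config = AutoConfig.from_pretrained("google/siglip-base-patch16-224")
    audio_config = AutoConfig.from_pretrained("openai/whisper-small")
    llm_config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    config = MM_LLMs_Config(
        attention_heads=8,
        audio_conv_kernel=240, audio_conv_stride=220,  # values shown in config.json above
        image_config=image_config,
        audio_config=audio_config,
        llm_config=llm_config,
    )
    model = MM_LLMs(config)  # random init; load model.safetensors for the trained weights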