aisyahhrazak committed on
Commit f8637f3
1 Parent(s): e564d99

Upload MM_LLMs
Files changed (3)
  1. config.json +5 -1
  2. model.safetensors +2 -2
  3. modeling.py +37 -31
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "multimodal-tinyllama-whisper-small-siglip/checkpoint-2500",
+  "_name_or_path": "multimodal-tinyllama-whisper-small-siglip/checkpoint-15500",
   "architectures": [
     "MM_LLMs"
   ],
@@ -206,6 +206,10 @@
   },
   "audio_conv_kernel": 240,
   "audio_conv_stride": 220,
+  "auto_map": {
+    "AutoConfig": "modeling.MM_LLMs_Config",
+    "AutoModel": "modeling.MM_LLMs"
+  },
   "hidden_size": 2048,
   "image_config": {
     "_name_or_path": "google/siglip-base-patch16-224",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9537a5b0cbb8c353c742959d43e527634e05a97ae06aed7ac23f668a1c06d924
-size 3509162622
+oid sha256:fd41c7ff4ba5d41aa0fb55a38f382e50d72f1da8d2521ff30c40d4e3a5a6bfb9
+size 3504434068
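The LFS pointer above records the SHA-256 digest and byte size of the new weights. A small verification sketch, assuming model.safetensors has already been downloaded into the working directory:

```python
# Sketch: check a downloaded model.safetensors against the LFS pointer values above.
import hashlib
import os

path = "model.safetensors"  # assumed local path
expected_oid = "fd41c7ff4ba5d41aa0fb55a38f382e50d72f1da8d2521ff30c40d4e3a5a6bfb9"
expected_size = 3504434068

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
        digest.update(chunk)

assert os.path.getsize(path) == expected_size, "size mismatch"
assert digest.hexdigest() == expected_oid, "sha256 mismatch"
```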
modeling.py CHANGED
@@ -83,21 +83,16 @@ class MM_LLMs(PreTrainedModel):
         attn_dropout = 0.1
         is_add_bias_kv = True
         is_add_zero_attn = True
-        self.temporal_self_attention = nn.MultiheadAttention(
-            config.image_config.text_config.hidden_size,
-            config.attention_heads,
-            dropout=attn_dropout,
-            add_bias_kv=is_add_bias_kv,
-            add_zero_attn=is_add_zero_attn)

+        self.num_heads = config.attention_heads * 2
         self.audio_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size,
-                                                           config.attention_heads * 2,
+                                                           self.num_heads,
                                                            dropout=attn_dropout,
                                                            add_bias_kv=is_add_bias_kv,
                                                            add_zero_attn=is_add_zero_attn)

         self.image_align_attention = nn.MultiheadAttention(config.llm_config.hidden_size,
-                                                           config.attention_heads * 2,
+                                                           self.num_heads,
                                                            dropout=attn_dropout,
                                                            add_bias_kv=is_add_bias_kv,
                                                            add_zero_attn=is_add_zero_attn)
@@ -123,12 +118,7 @@ class MM_LLMs(PreTrainedModel):
             self.config.image_config.text_config.hidden_size,
             bias=False)

-        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
-
         self.layer_norm = nn.LayerNorm(config.image_config.text_config.hidden_size)
-        self.softmax = nn.Softmax(dim=-1)
-
-        self.sigmoid = nn.Sigmoid()

         self.loss_fct = CrossEntropyLoss()

@@ -188,9 +178,7 @@ class MM_LLMs(PreTrainedModel):
             images=None,
             audios=None,
             audio_starts=None,
-            audio_ends=None,
             image_starts=None,
-            image_ends=None,
             attention_mask=None,
             labels=None,
             audio_index=None,
@@ -212,7 +200,6 @@ class MM_LLMs(PreTrainedModel):
         if audio_features is not None:

             audio_starts = embed_tokens(audio_starts).unsqueeze(1)
-            audio_ends = embed_tokens(audio_ends).unsqueeze(1)

             audio_features = self.project_audio(
                 audio_features.transpose(
@@ -232,29 +219,39 @@ class MM_LLMs(PreTrainedModel):
                     dim),
                 device=token_embeddings.device,
                 dtype=token_embeddings.dtype)
+            new_audio_mask = torch.ones(
+                (
+                    token_embeddings.shape[1] * self.num_heads,
+                    seq_img * max_count,
+                    token_embeddings.shape[0]
+                ),
+                device=token_embeddings.device,
+                dtype=torch.bool)
             current_dim = 0
             for no, index in enumerate(audio_index):
                 if no > 0 and audio_index[no - 1] == index:
                     current_dim += 1
                 else:
                     current_dim = 0
-                new_audio[index, current_dim *
-                          seq_img: (current_dim + 1) * seq_img] = audio_features[no]
-            last_index = audio_index[0]
+                new_audio[
+                    index, current_dim *
+                    seq_img: (current_dim + 1) * seq_img
+                ] = audio_features[no]
+                new_audio_mask[index * self.num_heads: (index + 1) * self.num_heads, current_dim *
+                               seq_img: (current_dim + 1) * seq_img] = 0

             audio_features = self.audio_align_attention(
                 new_audio.transpose(
                     0,
                     1),
                 token_embeddings,
-                token_embeddings)[0].transpose(
+                token_embeddings,
+                attn_mask=new_audio_mask
+            )[0].transpose(
                 0,
                 1).contiguous()

-            # audio_features = add_positional_encoding(audio_features)
-
-            audio_inputs = torch.cat(
-                [torch.cat([audio_starts, audio_features], dim=1), audio_ends], dim=1)
+            audio_inputs = torch.cat([audio_starts, audio_features], dim=1)

             text_embeddings = torch.cat(
                 [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), audio_inputs], dim=1), text_embeddings[:, 1:, :]],
@@ -265,7 +262,6 @@ class MM_LLMs(PreTrainedModel):
         if image_features is not None:

             image_starts = embed_tokens(image_starts).unsqueeze(1)
-            image_ends = embed_tokens(image_ends).unsqueeze(1)

             image_features = self.project_image(
                 image_features.transpose(
@@ -286,6 +282,16 @@ class MM_LLMs(PreTrainedModel):
                 device=token_embeddings.device,
                 dtype=token_embeddings.dtype)

+            new_img_mask = torch.ones(
+                (
+                    token_embeddings.shape[1] * self.num_heads,
+                    seq_img * max_count,
+                    token_embeddings.shape[0]
+                ),
+                device=token_embeddings.device,
+                dtype=torch.bool
+            )
+
             current_dim = 0
             for no, index in enumerate(image_index):
                 if no > 0 and image_index[no - 1] == index:
@@ -294,21 +300,21 @@ class MM_LLMs(PreTrainedModel):
                     current_dim = 0
                 new_img[index, current_dim *
                         seq_img: (current_dim + 1) * seq_img] = image_features[no]
-            last_index = image_index[0]
+                new_img_mask[index * self.num_heads: (index + 1) * self.num_heads, current_dim *
+                             seq_img: (current_dim + 1) * seq_img] = 0

             image_features = self.image_align_attention(
                 new_img.transpose(
                     0,
                     1),
                 token_embeddings,
-                token_embeddings)[0].transpose(
+                token_embeddings,
+                attn_mask=new_img_mask,
+            )[0].transpose(
                 0,
                 1).contiguous()

-            # image_features = add_positional_encoding(image_features)
-
-            image_inputs = torch.cat(
-                [torch.cat([image_starts, image_features], dim=1), image_ends], dim=1)
+            image_inputs = torch.cat([image_starts, image_features], dim=1)

             text_embeddings = torch.cat(
                 [torch.cat([text_embeddings[:, 0, :].unsqueeze(1), image_inputs], dim=1),
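The central modeling change is that both alignment attentions now receive a boolean attn_mask, so padded slots in the per-sample audio/image buffers are blocked from attending to the token embeddings. A standalone sketch of the nn.MultiheadAttention mask convention relied on here, with invented dimensions: under the default batch_first=False layout, a 3-D boolean mask has shape (batch_size * num_heads, target_len, source_len) and True marks pairs that may not attend.

```python
# Sketch of the attn_mask convention used by the alignment attentions above.
# Shapes here are illustrative, not the repo's actual sizes.
import torch
import torch.nn as nn

batch, heads, dim = 2, 4, 32
tgt_len, src_len = 6, 10   # e.g. stacked modality slots attending over token embeddings

attn = nn.MultiheadAttention(dim, heads, add_bias_kv=True, add_zero_attn=True)

query = torch.randn(tgt_len, batch, dim)        # (L, N, E) with batch_first=False
key = value = torch.randn(src_len, batch, dim)  # (S, N, E)

# (N * num_heads, L, S); True = do not attend. Start fully masked, then open the
# rows for the slots that are actually filled (here: the first three).
mask = torch.ones(batch * heads, tgt_len, src_len, dtype=torch.bool)
mask[:, :3, :] = False

out, _ = attn(query, key, value, attn_mask=mask)
print(out.shape)  # torch.Size([6, 2, 32])
```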