alex-ht committed
Commit 856d1dd
1 Parent(s): 3889df6
Files changed (3)
  1. config.json +4 -3
  2. mllama_audio_model.py +1 -34
  3. modeling_llama3.py +9 -398
config.json CHANGED
@@ -4,7 +4,8 @@
     "Llama3ForConditionalGeneration"
   ],
   "audio_config": {
-    "_attn_implementation_autoset": true,
+    "_attn_implementation_autoset": false,
+    "_attn_implementation": "eager",
     "_name_or_path": "",
     "activation_dropout": 0.0,
     "adapter_act": "relu",
@@ -143,8 +144,8 @@
   "auto_map": {
     "AutoConfig": "AlexHung29629/test_mllama_11B_v3--configuration_llama3.Llama3Config",
     "AutoModel": "AlexHung29629/test_mllama_11B_v3--modeling_llama3.Llama3ForConditionalGeneration",
-    "AutoModelForCausalLM": "AlexHung29629/test_mllama_11B_v3--modeling_llama3.Llama3ForCausalLM",
-    "AutoProcessor": "AlexHung29629/test_mllama_11B_v3--processing_mllama.MllamaProcessor"
+    "AutoProcessor": "AlexHung29629/test_mllama_11B_v3--processing_mllama.MllamaProcessor",
+    "AutoFeatureExtractor": "AlexHung29629/test_mllama_11B_v3--audio_processing_mllama.MllamaAudioFeatureExtractor"
   },
   "image_token_index": 128256,
   "model_type": "llama3",
mllama_audio_model.py CHANGED
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple, Union
+import math
 import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
@@ -14,43 +15,9 @@ class Llama3Embedding(Wav2Vec2BertPreTrainedModel):
         assert config.output_hidden_size == text_config.hidden_size
         self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
         self.audio_embedding = Wav2Vec2BertModel(config)
-        #assert self.text_embeddings.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
         self.text_config = text_config
-
-    def _init_weights(self, module):
-        std = self.text_config.initializer_range
-        """Initialize the weights"""
-        if isinstance(module, Wav2Vec2BertSelfAttention):
-            if hasattr(module, "pos_bias_u"):
-                nn.init.xavier_uniform_(module.pos_bias_u)
-            if hasattr(module, "pos_bias_v"):
-                nn.init.xavier_uniform_(module.pos_bias_v)
-        elif isinstance(module, Wav2Vec2BertFeatureProjection):
-            k = math.sqrt(1 / module.projection.in_features)
-            nn.init.uniform_(module.projection.weight, a=-k, b=k)
-            nn.init.uniform_(module.projection.bias, a=-k, b=k)
-        elif isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Conv1d):
-            nn.init.kaiming_normal_(module.weight)
-
-            if module.bias is not None:
-                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
-                nn.init.uniform_(module.bias, a=-k, b=k)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.Parameter):
-            module.data.normal_(mean=0.0, std=std)
 
     def forward(
         self,
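
For orientation, a construction sketch of the Llama3Embedding module this file defines. The config values below are illustrative placeholders rather than the checkpoint's real ones; the only constraint visible in the diff is that the audio `output_hidden_size` must equal the text `hidden_size`, and with the custom `_init_weights` override removed, weight initialization falls back to `Wav2Vec2BertPreTrainedModel`:

```python
from transformers import Wav2Vec2BertConfig
from transformers.models.mllama.configuration_mllama import MllamaTextConfig

from mllama_audio_model import Llama3Embedding  # assumes this file is importable locally

# Illustrative sub-configs; the real ones come from the checkpoint's config.json.
text_cfg = MllamaTextConfig(hidden_size=4096, vocab_size=128264, pad_token_id=128004)
audio_cfg = Wav2Vec2BertConfig(output_hidden_size=4096)

emb = Llama3Embedding(audio_cfg, text_cfg)  # asserts output_hidden_size == hidden_size
# start_of_audio / end_of_audio are zero-initialized, learnable (1, hidden_size) markers.
print(emb.start_of_audio.shape, emb.end_of_audio.shape)
```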
modeling_llama3.py CHANGED
@@ -6,405 +6,17 @@ import torch.utils.checkpoint
 from torch import nn
 
 import transformers
-from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, Wav2Vec2BertConfig, AutoModel, AutoModelForCausalLM
-from transformers.cache_utils import Cache, StaticCache
+from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel
 from transformers.generation import GenerationMixin
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import logging
-from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask, MllamaCrossAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaTextRMSNorm, MllamaRotaryEmbedding
-from transformers.models.mllama.configuration_mllama import MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
 from .configuration_llama3 import Llama3Config
 from .mllama_audio_model import Llama3Embedding
 
 
 logger = logging.get_logger(__name__)
 
-class Llama3PreTrainedModel(MllamaPreTrainedModel):
-    config_class = Llama3Config
-    base_model_prefix = "model"
-
-class Llama3TextModel(MllamaPreTrainedModel):
-    config_class = MllamaTextConfig
-    base_model_prefix = "language_model.model"
-
-    def __init__(self, config: MllamaTextConfig):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        #self.embed_tokens = Llama3Embedding(audio_config, config)
-        self.cross_attention_layers = config.cross_attention_layers
-
-        layers = []
-        for layer_idx in range(config.num_hidden_layers):
-            if layer_idx in self.cross_attention_layers:
-                layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx))
-            else:
-                layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx))
-
-        self.layers = nn.ModuleList(layers)
-        self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = MllamaRotaryEmbedding(config=config)
-        self.gradient_checkpointing = False
-        self.post_init()
-
-    def get_input_embeddings(self):
-        #return self.embed_tokens.text_embeddings
-        return None
-
-    def set_input_embeddings(self, value):
-        #self.embed_tokens.text_embeddings = value
-        pass
-
-    def forward(
-        self,
-        #input_ids: Optional[torch.LongTensor] = None,
-        #audio_features: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cross_attention_states: Optional[torch.FloatTensor] = None,
-        cross_attention_mask: Optional[torch.Tensor] = None,
-        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoProcessor, MllamaTextModel
-
-        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
-        >>> model = MllamaTextModel.from_pretrained(checkpoint)
-        >>> processor = AutoProcessor.from_pretrained(checkpoint)
-
-        >>> text = "<|image|>If I had to write a haiku for this one"
-        >>> inputs = processor(text=text, return_tensors="pt")
-
-        >>> output = model(**inputs)
-
-        >>> print(output.last_hidden_state.shape)
-        torch.Size([1, 13, 4096])
-        ```
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        #if (input_ids is None) ^ (inputs_embeds is not None):
-        #    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        #if inputs_embeds is None:
-        #    inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
-
-
-        hidden_states = inputs_embeds
-
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
-
-        # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            # For text-only path we should skip cross attention layers.
-            # Let's check if the layer is cross attention layer and if we have cross attention states
-            # or cached cross attention states.
-            is_cross_attention_layer = idx in self.cross_attention_layers
-            is_cross_attention_cache_empty = past_key_values is None or (
-                past_key_values is not None and past_key_values.get_seq_length(idx) == 0
-            )
-
-            if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
-                continue
-
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    cross_attention_states,
-                    cross_attention_mask,
-                    causal_mask,
-                    full_text_row_masked_out_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    cross_attention_states=cross_attention_states,
-                    cross_attention_mask=cross_attention_mask,
-                    attention_mask=causal_mask,
-                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            device=device,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
-class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
-    config_class = MllamaTextConfig
-    base_model_prefix = "model"
-    #_tied_weights_keys = ["lm_head.weight"]
-
-    def __init__(self, config: MllamaTextConfig):
-        super().__init__(config.get_text_config())
-        self.text_config = config.get_text_config()
-        self.vocab_size = self.text_config.vocab_size
-        self.model = Llama3TextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
-        self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
-
-        self.post_init()
-
-    def get_input_embeddings(self):
-        #return self.model.embed_tokens.text_embeddings
-        return None
-
-    def set_input_embeddings(self, value):
-        #self.model.embed_tokens.text_embeddings = value
-        pass
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
-
-    def forward(
-        self,
-        #input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cross_attention_states: Optional[torch.LongTensor] = None,
-        cross_attention_mask: Optional[torch.LongTensor] = None,
-        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
-        **loss_kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            #input_ids=input_ids,
-            cross_attention_states=cross_attention_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            cross_attention_mask=cross_attention_mask,
-            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
-
-        loss = None
-        if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-#AutoModelForCausalLM.register(MllamaTextConfig, Llama3ForCausalLM)
-#transformers.Llama3ForCausalLM = Llama3ForCausalLM
 
 class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
     config_class = Llama3Config
@@ -421,7 +33,6 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
 
         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
         self.language_model = MllamaForCausalLM._from_config(config.text_config)
-        self.language_model.get_input_embeddings().weight.required_grad = False
         self.embed_tokens = Llama3Embedding(config.audio_config, config.text_config)
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
@@ -431,10 +42,10 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         self.post_init()
 
     def get_input_embeddings(self):
-        return self.language_model.get_input_embeddings()
+        return self.embed_tokens.text_embeddings
 
     def set_input_embeddings(self, value):
-        self.language_model.set_input_embeddings(value)
+        self.embed_tokens.text_embeddings = value
 
     def get_output_embeddings(self):
         return self.language_model.get_output_embeddings()
@@ -565,8 +176,8 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
 
         outputs = self.language_model(
-            #input_ids=input_ids,
-            #audio_features=audio_features,
+            input_ids=None,
+            audio_features=None,
             attention_mask=attention_mask,
             position_ids=position_ids,
             cross_attention_states=cross_attention_states,
@@ -670,5 +281,5 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
             )
         return model_kwargs
 
-#AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
-#transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
+AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
+transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
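
The now-active registration at module level means that, once modeling_llama3.py has been imported (for example through trust_remote_code loading), the config class resolves directly to the conditional-generation model via the Auto API, and input embeddings are served from the fused text+audio embedding module instead of language_model. A usage sketch under those assumptions:

```python
from transformers import AutoConfig, AutoModel

repo = "AlexHung29629/test_mllama_11B_v3"
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)  # Llama3Config
model = AutoModel.from_pretrained(repo, trust_remote_code=True)    # Llama3ForConditionalGeneration

# Matches the new get_input_embeddings(): the nn.Embedding held by Llama3Embedding.
embeddings = model.get_input_embeddings()
print(type(model).__name__, embeddings.num_embeddings)
```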