Commit 856d1dd — alex-ht committed
Parent: 3889df6

update

Files changed:
- config.json (+4, -3)
- mllama_audio_model.py (+1, -34)
- modeling_llama3.py (+9, -398)
config.json CHANGED
@@ -4,7 +4,8 @@
     "Llama3ForConditionalGeneration"
   ],
   "audio_config": {
-    "_attn_implementation_autoset":
+    "_attn_implementation_autoset": false,
+    "_attn_implementation": "eager",
     "_name_or_path": "",
     "activation_dropout": 0.0,
     "adapter_act": "relu",
@@ -143,8 +144,8 @@
   "auto_map": {
     "AutoConfig": "AlexHung29629/test_mllama_11B_v3--configuration_llama3.Llama3Config",
     "AutoModel": "AlexHung29629/test_mllama_11B_v3--modeling_llama3.Llama3ForConditionalGeneration",
-    "
-    "
+    "AutoProcessor": "AlexHung29629/test_mllama_11B_v3--processing_mllama.MllamaProcessor",
+    "AutoFeatureExtractor": "AlexHung29629/test_mllama_11B_v3--audio_processing_mllama.MllamaAudioFeatureExtractor"
   },
   "image_token_index": 128256,
   "model_type": "llama3",
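The two new auto_map entries expose the custom processor and audio feature extractor through the Auto classes. A minimal loading sketch (the repo id comes from the auto_map paths above; the call pattern is a standard trust_remote_code load and is not part of this commit):

    from transformers import AutoConfig, AutoProcessor, AutoFeatureExtractor

    repo = "AlexHung29629/test_mllama_11B_v3"
    # AutoConfig already resolves Llama3Config via the existing auto_map entry
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    # The new entries let these resolve to MllamaProcessor and MllamaAudioFeatureExtractor
    processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
    feature_extractor = AutoFeatureExtractor.from_pretrained(repo, trust_remote_code=True)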
mllama_audio_model.py CHANGED
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple, Union
+import math
 import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
@@ -14,43 +15,9 @@ class Llama3Embedding(Wav2Vec2BertPreTrainedModel):
         assert config.output_hidden_size == text_config.hidden_size
         self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
         self.audio_embedding = Wav2Vec2BertModel(config)
-        #assert self.text_embeddings.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
         self.text_config = text_config
-
-    def _init_weights(self, module):
-        std = self.text_config.initializer_range
-        """Initialize the weights"""
-        if isinstance(module, Wav2Vec2BertSelfAttention):
-            if hasattr(module, "pos_bias_u"):
-                nn.init.xavier_uniform_(module.pos_bias_u)
-            if hasattr(module, "pos_bias_v"):
-                nn.init.xavier_uniform_(module.pos_bias_v)
-        elif isinstance(module, Wav2Vec2BertFeatureProjection):
-            k = math.sqrt(1 / module.projection.in_features)
-            nn.init.uniform_(module.projection.weight, a=-k, b=k)
-            nn.init.uniform_(module.projection.bias, a=-k, b=k)
-        elif isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Conv1d):
-            nn.init.kaiming_normal_(module.weight)
-
-            if module.bias is not None:
-                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
-                nn.init.uniform_(module.bias, a=-k, b=k)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, nn.Parameter):
-            module.data.normal_(mean=0.0, std=std)
 
     def forward(
         self,
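With the custom _init_weights override removed, weight initialization for this module now falls back to the Wav2Vec2BertPreTrainedModel default. A quick construction sketch (assumes the repo's mllama_audio_model.py is available as a local module; not part of the commit):

    from transformers import AutoConfig
    from mllama_audio_model import Llama3Embedding  # local copy of this file

    cfg = AutoConfig.from_pretrained("AlexHung29629/test_mllama_11B_v3", trust_remote_code=True)
    # __init__ asserts audio_config.output_hidden_size == text_config.hidden_size
    embed = Llama3Embedding(cfg.audio_config, cfg.text_config)
    print(embed.start_of_audio.shape)  # torch.Size([1, output_hidden_size])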
modeling_llama3.py CHANGED
@@ -6,405 +6,17 @@ import torch.utils.checkpoint
 from torch import nn
 
 import transformers
-from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM,
-from transformers.cache_utils import Cache, StaticCache
+from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, AutoModel
 from transformers.generation import GenerationMixin
-from transformers.
-from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.utils import logging
-from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
-from transformers.models.mllama.configuration_mllama import MllamaTextConfig
+from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask
 from .configuration_llama3 import Llama3Config
 from .mllama_audio_model import Llama3Embedding
 
 
 logger = logging.get_logger(__name__)
 
-class Llama3PreTrainedModel(MllamaPreTrainedModel):
-    config_class = Llama3Config
-    base_model_prefix = "model"
-
-class Llama3TextModel(MllamaPreTrainedModel):
-    config_class = MllamaTextConfig
-    base_model_prefix = "language_model.model"
-
-    def __init__(self, config: MllamaTextConfig):
-        super().__init__(config)
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        #self.embed_tokens = Llama3Embedding(audio_config, config)
-        self.cross_attention_layers = config.cross_attention_layers
-
-        layers = []
-        for layer_idx in range(config.num_hidden_layers):
-            if layer_idx in self.cross_attention_layers:
-                layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx))
-            else:
-                layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx))
-
-        self.layers = nn.ModuleList(layers)
-        self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb = MllamaRotaryEmbedding(config=config)
-        self.gradient_checkpointing = False
-        self.post_init()
-
-    def get_input_embeddings(self):
-        #return self.embed_tokens.text_embeddings
-        return None
-
-    def set_input_embeddings(self, value):
-        #self.embed_tokens.text_embeddings = value
-        pass
-
-    def forward(
-        self,
-        #input_ids: Optional[torch.LongTensor] = None,
-        #audio_features: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cross_attention_states: Optional[torch.FloatTensor] = None,
-        cross_attention_mask: Optional[torch.Tensor] = None,
-        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        """
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoProcessor, MllamaTextModel
-
-        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
-        >>> model = MllamaTextModel.from_pretrained(checkpoint)
-        >>> processor = AutoProcessor.from_pretrained(checkpoint)
-
-        >>> text = "<|image|>If I had to write a haiku for this one"
-        >>> inputs = processor(text=text, return_tensors="pt")
-
-        >>> output = model(**inputs)
-
-        >>> print(output.last_hidden_state.shape)
-        torch.Size([1, 13, 4096])
-        ```
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        #if (input_ids is None) ^ (inputs_embeds is not None):
-        #    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if self.gradient_checkpointing and self.training and use_cache:
-            logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-            )
-            use_cache = False
-
-        #if inputs_embeds is None:
-        #    inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
-
-
-        hidden_states = inputs_embeds
-
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
-        if position_ids is None:
-            position_ids = cache_position.unsqueeze(0)
-
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
-
-        # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            # For text-only path we should skip cross attention layers.
-            # Let's check if the layer is cross attention layer and if we have cross attention states
-            # or cached cross attention states.
-            is_cross_attention_layer = idx in self.cross_attention_layers
-            is_cross_attention_cache_empty = past_key_values is None or (
-                past_key_values is not None and past_key_values.get_seq_length(idx) == 0
-            )
-
-            if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
-                continue
-
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    cross_attention_states,
-                    cross_attention_mask,
-                    causal_mask,
-                    full_text_row_masked_out_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    cross_attention_states=cross_attention_states,
-                    cross_attention_mask=cross_attention_mask,
-                    attention_mask=causal_mask,
-                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = next_decoder_cache if use_cache else None
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            device=device,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
-class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
-    config_class = MllamaTextConfig
-    base_model_prefix = "model"
-    #_tied_weights_keys = ["lm_head.weight"]
-
-    def __init__(self, config: MllamaTextConfig):
-        super().__init__(config.get_text_config())
-        self.text_config = config.get_text_config()
-        self.vocab_size = self.text_config.vocab_size
-        self.model = Llama3TextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
-        self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
-
-        self.post_init()
-
-    def get_input_embeddings(self):
-        #return self.model.embed_tokens.text_embeddings
-        return None
-
-    def set_input_embeddings(self, value):
-        #self.model.embed_tokens.text_embeddings = value
-        pass
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
-
-    def forward(
-        self,
-        #input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        cross_attention_states: Optional[torch.LongTensor] = None,
-        cross_attention_mask: Optional[torch.LongTensor] = None,
-        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
-        **loss_kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            #input_ids=input_ids,
-            cross_attention_states=cross_attention_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            cross_attention_mask=cross_attention_mask,
-            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
-
-        loss = None
-        if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return CausalLMOutputWithPast(
-            loss=loss,
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-#AutoModelForCausalLM.register(MllamaTextConfig, Llama3ForCausalLM)
-#transformers.Llama3ForCausalLM = Llama3ForCausalLM
 
 class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
     config_class = Llama3Config
@@ -421,7 +33,6 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
 
         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
         self.language_model = MllamaForCausalLM._from_config(config.text_config)
-        self.language_model.get_input_embeddings().weight.required_grad = False
         self.embed_tokens = Llama3Embedding(config.audio_config, config.text_config)
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
@@ -431,10 +42,10 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
         self.post_init()
 
     def get_input_embeddings(self):
-        return self.
+        return self.embed_tokens.text_embeddings
 
     def set_input_embeddings(self, value):
-        self.
+        self.embed_tokens.text_embeddings = value
 
     def get_output_embeddings(self):
         return self.language_model.get_output_embeddings()
@@ -565,8 +176,8 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
        inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
 
        outputs = self.language_model(
-
-
+            input_ids=None,
+            audio_features=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
@@ -670,5 +281,5 @@ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
        )
        return model_kwargs
 
-
-
+AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
+transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
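The new module-level AutoModel.register call, together with the AutoModel entry added to auto_map in config.json, makes the composite model resolvable through the Auto API. A usage sketch (standard remote-code loading; the dtype choice is only an example, not something this commit specifies):

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "AlexHung29629/test_mllama_11B_v3",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    print(type(model).__name__)  # expected: Llama3ForConditionalGeneration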