AlexHung29629 committed
Commit 3889df6
Parent(s): 368067b
Upload 5 files
- audio_processing_mllama.py +61 -0
- configuration_llama3.py +112 -0
- mllama_audio_model.py +74 -0
- modeling_llama3.py +674 -0
- processing_mllama.py +377 -0
audio_processing_mllama.py
ADDED
@@ -0,0 +1,61 @@
import math
from typing import Dict, List, Optional, Union

import numpy as np
import transformers
from transformers.tokenization_utils_base import AudioInput
from transformers.models.seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
from transformers.utils import TensorType
from transformers.feature_extraction_utils import BatchFeature
from transformers import AutoFeatureExtractor


def build_audio_tokens(text: List[str], audio_features: Union[Dict, List[List[np.ndarray]]], audio_token="<|audio|>") -> List[str]:
    if not isinstance(audio_features, list):
        audio_features = audio_features['audio_features']
    bs = audio_features.shape[0]
    for i in range(bs):
        for j in range(len(audio_features[i])):
            tgt_token = f"<|audio_{j+1}|>" * get_num_embeddings(audio_features[i][j].shape[0])
            text[i] = text[i].replace(audio_token, tgt_token, 1)
    return text


def get_num_embeddings(num_frames, adapter_kernel_size=7, adapter_stride=4) -> int:
    return math.ceil((num_frames - adapter_kernel_size) / adapter_stride) + 1 + 2  # 2 = <|begin_of_audio|>, <|end_of_audio|>


class MllamaAudioFeatureExtractor(SeamlessM4TFeatureExtractor):

    def __call__(
        self,
        batch_audio_clips: List[List[AudioInput]],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        audio_features = [
            [
                super(MllamaAudioFeatureExtractor, self).__call__(
                    audio_j, sampling_rate=16000, return_attention_mask=False
                )['input_features'][0]
                for audio_j in audio_i
            ]
            for audio_i in batch_audio_clips
        ]
        packed_audio_features = self.pack_audio_clips(audio_features)

        encoded_audio_inputs = BatchFeature(
            data={
                "audio_features": packed_audio_features,
            },
            tensor_type=return_tensors,
        )

        return encoded_audio_inputs

    def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
        assert batch_audio_clips[0][0].ndim == 2  # sequence length x feature dimension
        # Determine output shape: (batch_size, max_num_clips, max_frames, feature_dim)
        batch_size = len(batch_audio_clips)
        max_num_clips = max([len(clips) for clips in batch_audio_clips])
        max_frames = max([clip.shape[0] for clips in batch_audio_clips for clip in clips])
        feature_dim = batch_audio_clips[0][0].shape[1]

        stacked_audio_clips = np.zeros((batch_size, max_num_clips, max_frames, feature_dim), dtype=np.float32)
        for i, clips in enumerate(batch_audio_clips):
            for j, clip in enumerate(clips):
                stacked_audio_clips[i, j, :clip.shape[0], :] = clip

        return stacked_audio_clips


AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor
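A minimal usage sketch for the extractor and the token-expansion helper above, assuming this file is importable as a local module and that clips are 16 kHz mono numpy arrays; the waveforms and shapes below are placeholders, not values from any checkpoint:

```python
import numpy as np

from audio_processing_mllama import MllamaAudioFeatureExtractor, build_audio_tokens

# Default SeamlessM4T mel-filterbank settings; a real setup would load them
# from the uploaded checkpoint with .from_pretrained(...).
extractor = MllamaAudioFeatureExtractor()

# One batch item with two placeholder clips (1 s and 2 s at 16 kHz).
clips = [[
    (0.01 * np.random.randn(16000)).astype(np.float32),
    (0.01 * np.random.randn(32000)).astype(np.float32),
]]
features = extractor(clips, return_tensors="np")
print(features["audio_features"].shape)  # (batch, max_num_clips, max_frames, feature_dim)

# Expand each <|audio|> placeholder into one <|audio_j|> token per expected
# audio embedding, matching get_num_embeddings() above.
text = build_audio_tokens(["<|audio|> then <|audio|> please transcribe"], features)
```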
configuration_llama3.py
ADDED
@@ -0,0 +1,112 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mllama model configuration"""

import os
from typing import Dict, List, Optional, Union

import transformers
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from transformers import Wav2Vec2BertConfig, AutoConfig
from transformers.models.mllama.configuration_mllama import MllamaVisionConfig, MllamaTextConfig

logger = logging.get_logger(__name__)


class Llama3Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MllamaForConditionalGeneration`]. It is used to instantiate an
    Mllama model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mllama-9B.

    e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaVisionConfig`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaTextConfig`):
            The config object or dictionary of the text backbone.
        image_token_index (`int`, *optional*, defaults to 128256):
            The image token index to encode the image prompt.

    Example:

    ```python
    >>> from transformers import MllamaForConditionalGeneration, MllamaConfig, MllamaVisionConfig, MllamaTextConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = MllamaVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = MllamaTextConfig()

    >>> # Initializing a mllama-11b style configuration
    >>> configuration = MllamaConfig(vision_config, text_config)

    >>> # Initializing a model from the mllama-11b style configuration
    >>> model = MllamaForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llama3"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        audio_config=None,
        image_token_index=128256,
        audio_token_index=128257,
        **kwargs,
    ):
        if vision_config is None:
            self.vision_config = MllamaVisionConfig()
            logger.info("vision_config is None, using default mllama vision config")
        elif isinstance(vision_config, dict):
            self.vision_config = MllamaVisionConfig(**vision_config)
        elif isinstance(vision_config, MllamaVisionConfig):
            self.vision_config = vision_config

        self.image_token_index = image_token_index

        if audio_config is None:
            self.audio_config = Wav2Vec2BertConfig()
            logger.info("audio_config is None, using default mllama audio config")
        elif isinstance(audio_config, dict):
            self.audio_config = Wav2Vec2BertConfig(**audio_config)
        elif isinstance(audio_config, Wav2Vec2BertConfig):
            self.audio_config = audio_config

        self.audio_token_index = audio_token_index

        if text_config is None:
            self.text_config = MllamaTextConfig()
            logger.info("text_config is None, using default mllama text config")
        elif isinstance(text_config, dict):
            self.text_config = MllamaTextConfig(**text_config)
        elif isinstance(text_config, MllamaTextConfig):
            self.text_config = text_config

        super().__init__(**kwargs)


AutoConfig.register("llama3", Llama3Config)
transformers.Llama3Config = Llama3Config
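A short sketch of how this composite config can be assembled from its three backbone configs. The values are illustrative rather than those of any released checkpoint, and the local import path assumes these files sit on the Python path; `add_adapter=True` and a matching `output_hidden_size` reflect the asserts in `mllama_audio_model.py` below:

```python
from transformers import Wav2Vec2BertConfig
from transformers.models.mllama.configuration_mllama import MllamaTextConfig, MllamaVisionConfig

from configuration_llama3 import Llama3Config

# Dicts are accepted for each sub-config as well.
text_config = MllamaTextConfig()
audio_config = Wav2Vec2BertConfig(add_adapter=True, output_hidden_size=text_config.hidden_size)

config = Llama3Config(
    vision_config=MllamaVisionConfig(),
    text_config=text_config,
    audio_config=audio_config,
    image_token_index=128256,
    audio_token_index=128257,
)
print(config.model_type)  # "llama3"
```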
mllama_audio_model.py
ADDED
@@ -0,0 +1,74 @@
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput
from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, Wav2Vec2BertPreTrainedModel
# Needed by _init_weights below.
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
    Wav2Vec2BertFeatureProjection,
    Wav2Vec2BertSelfAttention,
)
from transformers.models.mllama.configuration_mllama import MllamaTextConfig


class Llama3Embedding(Wav2Vec2BertPreTrainedModel):
    base_model_prefix = "audio_model"

    def __init__(self, config: Wav2Vec2BertConfig, text_config: MllamaTextConfig):
        super().__init__(config)
        assert config.add_adapter is True, f'{type(self).__name__} requires `add_adapter` to be True.'
        assert config.output_hidden_size == text_config.hidden_size
        self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
        self.audio_embedding = Wav2Vec2BertModel(config)
        #assert self.text_embeddings.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
        self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
        self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
        self.text_config = text_config

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.text_config.initializer_range
        if isinstance(module, Wav2Vec2BertSelfAttention):
            if hasattr(module, "pos_bias_u"):
                nn.init.xavier_uniform_(module.pos_bias_u)
            if hasattr(module, "pos_bias_v"):
                nn.init.xavier_uniform_(module.pos_bias_v)
        elif isinstance(module, Wav2Vec2BertFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Parameter):
            module.data.normal_(mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        audio_features: Optional[torch.Tensor] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
        input_embeddings = self.text_embeddings(torch.clamp(input_ids, min=0))
        if audio_features is None:
            return input_embeddings
        bs, max_num_img, l, d = audio_features.shape
        audio_embeddings = self.audio_embedding(input_features=audio_features.view((bs * max_num_img, l, d)))['last_hidden_state']
        audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, self.start_of_audio.shape[-1]))

        for i in range(bs):
            for j in range(max_num_img):
                audio_id = -1 - j
                if torch.any(input_ids[i] == audio_id):
                    positions = torch.nonzero(input_ids[i] == audio_id, as_tuple=True)
                    seq_len = input_embeddings[i][positions].shape[0] - 2
                    input_embeddings[i] = input_embeddings[i].index_put(
                        positions,
                        torch.concat([self.start_of_audio, audio_embeddings[i, j, :seq_len, :], self.end_of_audio]),
                        accumulate=False,
                    )
        return input_embeddings
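The splicing logic above relies on a convention set up by the processor: every position belonging to audio clip `j` carries the negative id `-1 - j` in `input_ids`, and that span is overwritten with `start_of_audio`, the adapter outputs for clip `j`, and `end_of_audio`. A deliberately tiny sketch of that convention follows; the small config values are placeholders, and only the `add_adapter` / `output_hidden_size` constraints from the asserts are essential:

```python
import torch
from transformers import Wav2Vec2BertConfig
from transformers.models.mllama.configuration_mllama import MllamaTextConfig

from mllama_audio_model import Llama3Embedding

# Tiny illustrative configs (real checkpoints use much larger towers).
text_config = MllamaTextConfig(vocab_size=1024, hidden_size=64, pad_token_id=0)
audio_config = Wav2Vec2BertConfig(
    hidden_size=64, output_hidden_size=64, add_adapter=True, num_hidden_layers=2
)
embed = Llama3Embedding(audio_config, text_config)

# Regular text ids are >= 0; the four slots marked -1 belong to audio clip 0 and
# receive <|begin_of_audio|>, two adapter frames, and <|end_of_audio|>.
input_ids = torch.tensor([[1, 2, -1, -1, -1, -1, 3]])
audio_features = torch.randn(1, 1, 200, audio_config.feature_projection_input_dim)

embeddings = embed(input_ids=input_ids, audio_features=audio_features)
print(embeddings.shape)  # torch.Size([1, 7, 64])
```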
modeling_llama3.py
ADDED
@@ -0,0 +1,674 @@
# coding=utf-8
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

import transformers
from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, Wav2Vec2BertConfig, AutoModel, AutoModelForCausalLM
from transformers.cache_utils import Cache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.utils import logging
from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask, MllamaCrossAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaTextRMSNorm, MllamaRotaryEmbedding
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from .configuration_llama3 import Llama3Config
from .mllama_audio_model import Llama3Embedding


logger = logging.get_logger(__name__)


class Llama3PreTrainedModel(MllamaPreTrainedModel):
    config_class = Llama3Config
    base_model_prefix = "model"


class Llama3TextModel(MllamaPreTrainedModel):
    config_class = MllamaTextConfig
    base_model_prefix = "language_model.model"

    def __init__(self, config: MllamaTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        #self.embed_tokens = Llama3Embedding(audio_config, config)
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_idx in range(config.num_hidden_layers):
            if layer_idx in self.cross_attention_layers:
                layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx))
            else:
                layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx))

        self.layers = nn.ModuleList(layers)
        self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = MllamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        #return self.embed_tokens.text_embeddings
        return None

    def set_input_embeddings(self, value):
        #self.embed_tokens.text_embeddings = value
        pass

    def forward(
        self,
        #input_ids: Optional[torch.LongTensor] = None,
        #audio_features: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.FloatTensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """

        Returns:

        Example:

        ```python
        >>> from transformers import AutoProcessor, MllamaTextModel

        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
        >>> model = MllamaTextModel.from_pretrained(checkpoint)
        >>> processor = AutoProcessor.from_pretrained(checkpoint)

        >>> text = "<|image|>If I had to write a haiku for this one"
        >>> inputs = processor(text=text, return_tensors="pt")

        >>> output = model(**inputs)

        >>> print(output.last_hidden_state.shape)
        torch.Size([1, 13, 4096])
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        #if (input_ids is None) ^ (inputs_embeds is not None):
        #    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        #if inputs_embeds is None:
        #    inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)

        hidden_states = inputs_embeds

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # For text-only path we should skip cross attention layers.
            # Let's check if the layer is cross attention layer and if we have cross attention states
            # or cached cross attention states.
            is_cross_attention_layer = idx in self.cross_attention_layers
            is_cross_attention_cache_empty = past_key_values is None or (
                past_key_values is not None and past_key_values.get_seq_length(idx) == 0
            )

            if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
                continue

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    cross_attention_states,
                    cross_attention_mask,
                    causal_mask,
                    full_text_row_masked_out_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    cross_attention_states=cross_attention_states,
                    cross_attention_mask=cross_attention_mask,
                    attention_mask=causal_mask,
                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
    config_class = MllamaTextConfig
    base_model_prefix = "model"
    #_tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MllamaTextConfig):
        super().__init__(config.get_text_config())
        self.text_config = config.get_text_config()
        self.vocab_size = self.text_config.vocab_size
        self.model = Llama3TextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
        self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        #return self.model.embed_tokens.text_embeddings
        return None

    def set_input_embeddings(self, value):
        #self.model.embed_tokens.text_embeddings = value
        pass

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        #input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.LongTensor] = None,
        cross_attention_mask: Optional[torch.LongTensor] = None,
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            #input_ids=input_ids,
            cross_attention_states=cross_attention_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


#AutoModelForCausalLM.register(MllamaTextConfig, Llama3ForCausalLM)
#transformers.Llama3ForCausalLM = Llama3ForCausalLM


class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
    config_class = Llama3Config
    base_model_prefix = "model"
    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting

    def __init__(self, config: Llama3Config):
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.vision_model = MllamaVisionModel._from_config(config.vision_config)
        self.language_model = MllamaForCausalLM._from_config(config.text_config)
        self.language_model.get_input_embeddings().weight.requires_grad = False
        self.embed_tokens = Llama3Embedding(config.audio_config, config.text_config)
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_features: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
        >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
        >>> processor = AutoProcessor.from_pretrained(checkpoint)

        >>> prompt = "<|image|>If I had to write a haiku for this one"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> output = model.generate(**inputs, max_new_tokens=15)

        >>> prompt_len = inputs.input_ids.shape[-1]
        >>> generated_ids = output[:, prompt_len:]
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        >>> print(generated_text)
        [', it would be:.\\nA stop sign in Chinatown.\\n']
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
            # get vision tokens from vision model
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )
            cross_attention_states = vision_outputs[0]
            cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        else:
            full_text_row_masked_out_mask = None

        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)

        outputs = self.language_model(
            #input_ids=input_ids,
            #audio_features=audio_features,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        audio_features=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`,
                # as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using
                # `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but
                # with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
            }
        )

        # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
        # to compute image hidden states, otherwise they are cached within each cross attn layer
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask

        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # add cross-attn mask for new token
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat(
                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
            )
        return model_kwargs


#AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
#transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
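A rough end-to-end sketch of how these pieces are meant to compose at inference time. The repo id is a placeholder, the image and waveform are stand-ins, and the nesting expected for `audio` (one list of clips per batch item) is inferred from `MllamaAudioFeatureExtractor`, since the processor's `__call__` body is truncated further down the page:

```python
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor

from modeling_llama3 import Llama3ForConditionalGeneration

model_id = "path/to/llama3-audio-checkpoint"  # placeholder repo id
model = Llama3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.new("RGB", (448, 448))          # stand-in image
waveform = np.zeros(16000, dtype=np.float32)  # stand-in 1 s clip at 16 kHz

prompt = "<|image|><|audio|>Describe what you see and hear."
inputs = processor(images=image, text=prompt, audio=[[waveform]], return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output, skip_special_tokens=True))
```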
processing_mllama.py
ADDED
@@ -0,0 +1,377 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Processor class for Mllama."""
|
17 |
+
|
18 |
+
from typing import List, Optional, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
import transformers
|
23 |
+
from transformers import AutoProcessor
|
24 |
+
from transformers.feature_extraction_utils import BatchFeature
|
25 |
+
from transformers.image_utils import ImageInput
|
26 |
+
from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
|
27 |
+
from transformers.tokenization_utils_base import (
|
28 |
+
PreTokenizedInput,
|
29 |
+
TextInput,
|
30 |
+
AudioInput,
|
31 |
+
)
|
32 |
+
|
33 |
+
# TODO: Can we do it that way or its better include as "Copied from ..."
|
34 |
+
from transformers.models.mllama.image_processing_mllama import make_list_of_images
|
35 |
+
from .audio_processing_mllama import build_audio_tokens
|
36 |
+
|
37 |
+
|
38 |
+
class MllamaImagesKwargs(ImagesKwargs, total=False):
|
39 |
+
max_image_tiles: Optional[int]
|
40 |
+
|
41 |
+
class MllamaProcessorKwargs(ProcessingKwargs, total=False):
|
42 |
+
images_kwargs: MllamaImagesKwargs
|
43 |
+
|
44 |
+
_defaults = {
|
45 |
+
"image_kwargs": {
|
46 |
+
"max_image_tiles": 4,
|
47 |
+
},
|
48 |
+
}
|
49 |
+
|
50 |
+
|
51 |
+
def get_cross_attention_token_mask(input_ids: List[int], image_token_id: int) -> List[List[int]]:
|
52 |
+
"""
|
53 |
+
Generate a cross-attention token mask for image tokens in the input sequence.
|
54 |
+
|
55 |
+
This function identifies the positions of image tokens in the input sequence and creates
|
56 |
+
a mask that defines which subsequent tokens each image token should attend to.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
input_ids (List[int]): A list of token ids representing the input sequence.
|
60 |
+
image_token_id (int): The id of the token used to represent images in the sequence.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
List[List[int]]: A list of [start, end] pairs, where each pair represents the range
|
64 |
+
of tokens an image token should attend to.
|
65 |
+
|
66 |
+
Notes:
|
67 |
+
- If no image tokens are present, an empty list is returned.
|
68 |
+
- For a single image token, it attends to all subsequent tokens until the end of the sequence.
|
69 |
+
- For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
|
70 |
+
- Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
|
71 |
+
"""
|
72 |
+
|
73 |
+
image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]
|
74 |
+
|
75 |
+
if len(image_token_locations) == 0:
|
76 |
+
return []
|
77 |
+
|
78 |
+
# only one image present, unmask until end of sequence
|
79 |
+
if len(image_token_locations) == 1:
|
80 |
+
return [[image_token_locations[0], -1]]
|
81 |
+
|
82 |
+
vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]
|
83 |
+
|
84 |
+
# last image will attend to all subsequent text
|
85 |
+
vision_masks.append([image_token_locations[-1], len(input_ids)])
|
86 |
+
|
87 |
+
# if there are two or more consecutive vision tokens,
|
88 |
+
# they should all attend to all subsequent
|
89 |
+
# text present
|
90 |
+
last_mask_end = vision_masks[-1][1]
|
91 |
+
for vision_mask in vision_masks[::-1]:
|
92 |
+
if vision_mask[0] == vision_mask[1] - 1:
|
93 |
+
vision_mask[1] = last_mask_end
|
94 |
+
last_mask_end = vision_mask[1]
|
95 |
+
|
96 |
+
return vision_masks
|
97 |
+
|
98 |
+
|
99 |
+
def convert_sparse_cross_attention_mask_to_dense(
|
100 |
+
cross_attention_token_mask: List[List[List[int]]],
|
101 |
+
num_tiles: List[List[int]],
|
102 |
+
max_num_tiles: int,
|
103 |
+
length: int,
|
104 |
+
) -> np.ndarray:
|
105 |
+
"""
|
106 |
+
Convert the cross attention mask indices to a cross attention mask 4D array.
|
107 |
+
|
108 |
+
This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
|
109 |
+
The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
cross_attention_token_mask (List[List[List[int]]]): A nested list structure where:
|
113 |
+
- The outer list represents the batch dimension.
|
114 |
+
- The middle list represents different images within each batch item.
|
115 |
+
- The inner list contains pairs of integers [start, end] representing token ranges for each image.
|
116 |
+
num_tiles (List[List[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
|
117 |
+
max_num_tiles (int): The maximum possible number of tiles.
|
118 |
+
length (int): The total sequence length of the input.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles)
|
122 |
+
The array contains `1` where attention is allowed and `0` where it is not.
|
123 |
+
|
124 |
+
Note:
|
125 |
+
- Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
|
126 |
+
"""
|
127 |
+
|
128 |
+
batch_size = len(cross_attention_token_mask)
|
129 |
+
max_num_images = max([len(masks) for masks in cross_attention_token_mask])
|
130 |
+
|
131 |
+
cross_attention_mask = np.zeros(
|
132 |
+
shape=(batch_size, length, max_num_images, max_num_tiles),
|
133 |
+
dtype=np.int64,
|
134 |
+
)
|
135 |
+
|
136 |
+
for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
|
137 |
+
for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
|
138 |
+
if len(locations) == 2:
|
139 |
+
start, end = locations
|
140 |
+
end = min(end, length)
|
141 |
+
if end == -1:
|
142 |
+
end = length
|
143 |
+
cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
|
144 |
+
return cross_attention_mask
|
145 |
+
|
146 |
+
|
147 |
+
def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
    """
    Builds a string from the input prompt by adding `bos_token` if not already present.

    Args:
        prompt (`str`):
            The input prompt string.
        bos_token (`str`):
            The beginning of sentence token to be added.
        image_token (`str`):
            The image token used to identify the start of an image sequence.

    Returns:
        str: The modified prompt string with the `bos_token` added if necessary.

    Examples:
        >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
        '<begin_of_text>Hello world'

        >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
        '<|image|><begin_of_text>Hello world'

        >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
        '<begin_of_text>Hello world'
    """

    if bos_token in prompt:
        return prompt

    num_image_tokens_on_start = 0
    while prompt.startswith(image_token):
        prompt = prompt[len(image_token) :]
        num_image_tokens_on_start += 1

    return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"


class MllamaProcessor(ProcessorMixin):
    r"""
    Constructs a Mllama processor which wraps [`MllamaImageProcessor`], [`MllamaAudioFeatureExtractor`] and
    [`PreTrainedTokenizerFast`] into a single processor that inherits the image processor, audio feature
    extractor and tokenizer functionalities. See [`~MllamaProcessor.__call__`] and [`~MllamaProcessor.decode`]
    for more information.
    The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
    ```python
    from transformers import MllamaProcessor
    from PIL import Image

    processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")

    processor(
        images=your_pil_image,
        text=["<|image|>If I had to write a haiku for this one"],
        images_kwargs={"size": {"height": 448, "width": 448}},
        text_kwargs={"padding": "right"},
        common_kwargs={"return_tensors": "pt"},
    )
    ```

    Args:
        image_processor ([`MllamaImageProcessor`]):
            The image processor is a required input.
        audio_processor ([`MllamaAudioFeatureExtractor`]):
            The audio feature extractor is a required input.
        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "audio_processor", "tokenizer"]
    image_processor_class = "MllamaImageProcessor"
    audio_processor_class = "MllamaAudioFeatureExtractor"
    tokenizer_class = "PreTrainedTokenizerFast"

    def __init__(self, image_processor, audio_processor, tokenizer):
        self.image_token = "<|image|>"
        self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.audio_token = "<|audio|>"
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
        self.python_token = "<|python_tag|>"
        self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
        self.bos_token = tokenizer.bos_token
        self.chat_template = tokenizer.chat_template
        super().__init__(image_processor, audio_processor, tokenizer)
        self.tokenizer.add_tokens([f"<|audio_{i}|>" for i in range(1, 50)])

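    # Note (added for clarity): __init__ registers 49 numbered `<|audio_{i}|>` placeholder
    # tokens; `__call__` below substitutes them into the prompt for each audio clip via
    # `build_audio_tokens` and then remaps their ids before returning the encoding.
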
    def __call__(
        self,
        images: Optional[ImageInput] = None,
        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        audio: Optional[Union[AudioInput, List[AudioInput]]] = None,
        videos=None,
        **kwargs: Unpack[MllamaProcessorKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare text(s), image(s) and audio to be fed as input to the model. This method forwards
        the `text` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not
        `None` to encode the text. To prepare the image(s), this method forwards the `images` arguments to
        MllamaImageProcessor's [`~MllamaImageProcessor.__call__`] if `images` is not `None`. Audio clips are
        forwarded to the audio feature extractor if `audio` is not `None`. Please refer to the docstrings of
        these methods for more information.

        Args:
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                    - `'tf'`: Return TensorFlow `tf.constant` objects.
                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
                    - `'np'`: Return NumPy `np.ndarray` objects.
                    - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **audio_features** -- Audio features extracted using SeamlessM4TFeatureExtractor. Returned when `audio` is not `None`.
            TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
        """
        if text is None:
            raise ValueError("You must specify text.")

        output_kwargs = self._merge_kwargs(
            MllamaProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        text_kwargs = output_kwargs["text_kwargs"]
        images_kwargs = output_kwargs["images_kwargs"]
        common_kwargs = output_kwargs["common_kwargs"]

        data = {}

        if audio is not None:
            audio_features = self.audio_processor(audio)
            data.update(audio_features)

        if isinstance(text, str):
            text = [text]
        elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")
        n_images_in_text = [t.count(self.image_token) for t in text]
        text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
        _ = text_kwargs.pop("padding_side", None)  # hack until padding-side is an accepted kwarg by tokenizers

        if audio is not None:
            text = build_audio_tokens(text, audio_features, self.audio_token)

        encoding = self.tokenizer(text, add_special_tokens=False, **text_kwargs)
        if audio is not None:
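            # Note (added for clarity): the numbered audio placeholder tokens are remapped to
            # negative ids below (`<|audio_1|>` -> -1, `<|audio_2|>` -> -2, ...), which appears
            # intended to let the model distinguish audio slots from regular vocabulary ids.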
            beg_audio_id = self.tokenizer.convert_tokens_to_ids("<|audio_1|>")
            idx = torch.where(encoding['input_ids'] >= beg_audio_id)
            encoding['input_ids'][idx] = beg_audio_id - encoding['input_ids'][idx] - 1
        data.update(encoding)

        n_images_in_images = [0]
        if images is not None:
            images = make_list_of_images(images)
            n_images_in_images = [len(sample) for sample in images]

        if text is not None:
            if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
                batch_img == 0 for batch_img in n_images_in_text
            ):
                raise ValueError(
                    "If a batch of text is provided, there should be either no images or at least one image per sample"
                )
            if sum(n_images_in_images) != sum(n_images_in_text):
                if images is None:
                    raise ValueError("No images were provided, but there are image tokens in the prompt")
                else:
                    raise ValueError(
                        f"The number of image tokens ({sum(n_images_in_text)}) should match the number of provided images ({sum(n_images_in_images)})"
                    )

        if images is not None:
            image_features = self.image_processor(images, **images_kwargs)
            num_tiles = image_features.pop("num_tiles")
            data.update(image_features)

        # Create cross attention mask
        if images is not None and text is not None:
            cross_attention_token_mask = [
                get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
            ]
            cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles=num_tiles,
                max_num_tiles=self.image_processor.max_image_tiles,
                length=max(len(input_ids) for input_ids in encoding["input_ids"]),
            )
            data["cross_attention_mask"] = cross_attention_mask

        return_tensors = common_kwargs.pop("return_tensors", None)
        batch_feature = BatchFeature(data=data, tensor_type=return_tensors)

        return batch_feature

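    # Usage sketch (illustrative only; the checkpoint path and input variables below are
    # placeholders, not part of this module):
    # >>> processor = MllamaProcessor.from_pretrained("path/to/local/checkpoint")
    # >>> inputs = processor(
    # ...     text=["<|image|><|audio|>Describe what you can see and hear."],
    # ...     images=[[pil_image]],
    # ...     audio=[[audio_clip]],
    # ...     return_tensors="pt",
    # ... )
    # >>> # Expected keys include input_ids, attention_mask, pixel_values, audio_features
    # >>> # and cross_attention_mask.
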
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(
            tokenizer_input_names
            + image_processor_input_names
            + ["cross_attention_mask"]
            + audio_processor_input_names
        )


AutoProcessor.register("MllamaProcessor", MllamaProcessor)
transformers.MllamaProcessor = MllamaProcessor
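
# Loading sketch (assumes this file is shipped as custom code in a hub repository; the repo id
# below is hypothetical):
# >>> from transformers import AutoProcessor
# >>> processor = AutoProcessor.from_pretrained("your-org/llama3-vision-audio", trust_remote_code=True)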