AlexHung29629 committed on
Commit
3889df6
1 Parent(s): 368067b

Upload 5 files

audio_processing_mllama.py ADDED
@@ -0,0 +1,61 @@
+ import math
+ from typing import Dict, List, Optional, Union
+ import numpy as np
+ import transformers
+ from transformers.tokenization_utils_base import AudioInput
+ from transformers.models.seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor
+ from transformers.utils import TensorType
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers import AutoFeatureExtractor
+
+
+ def build_audio_tokens(text: List[str], audio_features: Union[Dict, List[List[np.ndarray]]], audio_token="<|audio|>") -> List[str]:
+     if not isinstance(audio_features, list):
+         audio_features = audio_features['audio_features']
+     bs = len(audio_features)  # works for both the packed array and the nested-list form
+     for i in range(bs):
+         for j in range(len(audio_features[i])):
+             tgt_token = f"<|audio_{j+1}|>" * get_num_embeddings(audio_features[i][j].shape[0])
+             text[i] = text[i].replace(audio_token, tgt_token, 1)
+     return text
+
+ def get_num_embeddings(num_frames, adapter_kernel_size=7, adapter_stride=4) -> int:
+     return math.ceil((num_frames - adapter_kernel_size) / adapter_stride) + 1 + 2  # 2 = <|begin_of_audio|>, <|end_of_audio|>
+
+
+ class MllamaAudioFeatureExtractor(SeamlessM4TFeatureExtractor):
+
+     def __call__(
+         self,
+         batch_audio_clips: List[List[AudioInput]],
+         return_tensors: Optional[Union[str, TensorType]] = None,
+     ) -> BatchFeature:
+         audio_features = [
+             [
+                 super(MllamaAudioFeatureExtractor, self).__call__(
+                     audio_j, sampling_rate=16000, return_attention_mask=False
+                 )['input_features'][0]
+                 for audio_j in audio_i
+             ]
+             for audio_i in batch_audio_clips
+         ]
+         packed_audio_features = self.pack_audio_clips(audio_features)
+
+         encoded_audio_inputs = BatchFeature(
+             data={
+                 "audio_features": packed_audio_features,
+             },
+             tensor_type=return_tensors,
+         )
+
+         return encoded_audio_inputs
+
+     def pack_audio_clips(self, batch_audio_clips: List[List[np.ndarray]]) -> np.ndarray:
+         assert batch_audio_clips[0][0].ndim == 2  # sequence length x feature dimension
+         # Determine output shape: (batch_size, max_num_clips, max_frames, feature_dim)
+         batch_size = len(batch_audio_clips)
+         max_num_clips = max([len(clips) for clips in batch_audio_clips])
+         max_frames = max([clip.shape[0] for clips in batch_audio_clips for clip in clips])
+         feature_dim = batch_audio_clips[0][0].shape[1]
+
+         stacked_audio_clips = np.zeros((batch_size, max_num_clips, max_frames, feature_dim), dtype=np.float32)
+         for i, clips in enumerate(batch_audio_clips):
+             for j, clip in enumerate(clips):
+                 stacked_audio_clips[i, j, :clip.shape[0], :] = clip
+
+         return stacked_audio_clips
+
+ AutoFeatureExtractor.register("MllamaAudioFeatureExtractor", MllamaAudioFeatureExtractor)
+ transformers.MllamaAudioFeatureExtractor = MllamaAudioFeatureExtractor
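
A quick way to sanity-check the extractor above is to run it on synthetic clips. The snippet below is illustrative only: it assumes 16 kHz mono waveforms as float32 NumPy arrays and the default SeamlessM4T filterbank settings; no checkpoint is involved.

```python
import numpy as np

# Two samples: the first with one clip, the second with two clips of different lengths.
extractor = MllamaAudioFeatureExtractor()
clips = [
    [np.random.randn(16000).astype(np.float32)],
    [np.random.randn(32000).astype(np.float32), np.random.randn(8000).astype(np.float32)],
]
features = extractor(clips, return_tensors="np")
# Packed to (batch_size, max_num_clips, max_frames, feature_dim); shorter clips are zero-padded.
print(features["audio_features"].shape)
# Number of placeholder tokens the processor would emit for a clip with that many (padded) frames.
print(get_num_embeddings(features["audio_features"].shape[2]))
```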
configuration_llama3.py ADDED
@@ -0,0 +1,112 @@
+ # coding=utf-8
+ # Copyright 2024 HuggingFace Inc. team. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Mllama model configuration"""
+
+ import os
+ from typing import Dict, List, Optional, Union
+
+ import transformers
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+ from transformers import Wav2Vec2BertConfig, AutoConfig
+ from transformers.models.mllama.configuration_mllama import MllamaVisionConfig, MllamaTextConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ class Llama3Config(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`Llama3ForConditionalGeneration`]. It is used to instantiate an
+     Mllama model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the Mllama-9B.
+
+     e.g. [meta-llama/Llama-3.2-11B-Vision](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision)
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaVisionConfig`):
+             The config object or dictionary of the vision backbone.
+         text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `MllamaTextConfig`):
+             The config object or dictionary of the text backbone.
+         audio_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Wav2Vec2BertConfig`):
+             The config object or dictionary of the audio backbone.
+         image_token_index (`int`, *optional*, defaults to 128256):
+             The image token index to encode the image prompt.
+         audio_token_index (`int`, *optional*, defaults to 128257):
+             The audio token index to encode the audio prompt.
+
+     Example:
+
+     ```python
+     >>> from transformers import MllamaForConditionalGeneration, MllamaConfig, MllamaVisionConfig, MllamaTextConfig
+
+     >>> # Initializing a CLIP-vision config
+     >>> vision_config = MllamaVisionConfig()
+
+     >>> # Initializing a Llama config
+     >>> text_config = MllamaTextConfig()
+
+     >>> # Initializing a mllama-11b style configuration
+     >>> configuration = MllamaConfig(vision_config, text_config)
+
+     >>> # Initializing a model from the mllama-11b style configuration
+     >>> model = MllamaForConditionalGeneration(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "llama3"
+     is_composition = True
+
+     def __init__(
+         self,
+         vision_config=None,
+         text_config=None,
+         audio_config=None,
+         image_token_index=128256,
+         audio_token_index=128257,
+         **kwargs,
+     ):
+         if vision_config is None:
+             self.vision_config = MllamaVisionConfig()
+             logger.info("vision_config is None, using default mllama vision config")
+         elif isinstance(vision_config, dict):
+             self.vision_config = MllamaVisionConfig(**vision_config)
+         elif isinstance(vision_config, MllamaVisionConfig):
+             self.vision_config = vision_config
+
+         self.image_token_index = image_token_index
+
+         if audio_config is None:
+             self.audio_config = Wav2Vec2BertConfig()
+             logger.info("audio_config is None, using default mllama audio config")
+         elif isinstance(audio_config, dict):
+             self.audio_config = Wav2Vec2BertConfig(**audio_config)
+         elif isinstance(audio_config, Wav2Vec2BertConfig):
+             self.audio_config = audio_config
+
+         self.audio_token_index = audio_token_index
+
+         if text_config is None:
+             self.text_config = MllamaTextConfig()
+             logger.info("text_config is None, using default mllama text config")
+         elif isinstance(text_config, dict):
+             self.text_config = MllamaTextConfig(**text_config)
+         elif isinstance(text_config, MllamaTextConfig):
+             self.text_config = text_config
+
+         super().__init__(**kwargs)
+
+ AutoConfig.register("llama3", Llama3Config)
+ transformers.Llama3Config = Llama3Config
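
For reference, a minimal sketch of composing the config from sub-configs; the sizes below are placeholders, not the values of any released checkpoint.

```python
from transformers import Wav2Vec2BertConfig

# The audio tower must use the adapter and project to the text hidden size
# (Llama3Embedding asserts both, see mllama_audio_model.py).
audio_cfg = Wav2Vec2BertConfig(add_adapter=True, output_hidden_size=4096)
cfg = Llama3Config(
    vision_config=None,                 # falls back to the default MllamaVisionConfig
    text_config={"hidden_size": 4096},  # dicts are promoted to MllamaTextConfig
    audio_config=audio_cfg,
    audio_token_index=128257,
)
print(cfg.model_type, cfg.audio_config.output_hidden_size)  # llama3 4096
```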
mllama_audio_model.py ADDED
@@ -0,0 +1,74 @@
+ import math
+ from typing import Optional, Tuple, Union
+ import torch
+ from torch import nn
+ from transformers.modeling_outputs import BaseModelOutput
+ from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, Wav2Vec2BertPreTrainedModel
+ from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
+     Wav2Vec2BertFeatureProjection,
+     Wav2Vec2BertSelfAttention,
+ )
+ from transformers.models.mllama.configuration_mllama import MllamaTextConfig
+
+
+ class Llama3Embedding(Wav2Vec2BertPreTrainedModel):
+     base_model_prefix = "audio_model"
+
+     def __init__(self, config: Wav2Vec2BertConfig, text_config: MllamaTextConfig):
+         super().__init__(config)
+         assert config.add_adapter is True, f'{type(self).__name__} requires add_adapter to be true.'
+         assert config.output_hidden_size == text_config.hidden_size
+         self.text_embeddings = nn.Embedding(text_config.vocab_size, text_config.hidden_size, text_config.pad_token_id)
+         self.audio_embedding = Wav2Vec2BertModel(config)
+         #assert self.text_embeddings.weight.size(-1) == text_config.hidden_size, f"{self.text_embeddings.weight}, {text_config.hidden_size=}, {text_config.vocab_size=}"
+         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
+         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.output_hidden_size)), requires_grad=True)
+         self.text_config = text_config
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         std = self.text_config.initializer_range
+         if isinstance(module, Wav2Vec2BertSelfAttention):
+             if hasattr(module, "pos_bias_u"):
+                 nn.init.xavier_uniform_(module.pos_bias_u)
+             if hasattr(module, "pos_bias_v"):
+                 nn.init.xavier_uniform_(module.pos_bias_v)
+         elif isinstance(module, Wav2Vec2BertFeatureProjection):
+             k = math.sqrt(1 / module.projection.in_features)
+             nn.init.uniform_(module.projection.weight, a=-k, b=k)
+             nn.init.uniform_(module.projection.bias, a=-k, b=k)
+         elif isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight)
+             if module.bias is not None:
+                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                 nn.init.uniform_(module.bias, a=-k, b=k)
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.Parameter):
+             module.data.normal_(mean=0.0, std=std)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         audio_features: Optional[torch.Tensor] = None,
+     ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+         # Negative ids mark audio placeholder positions; clamp them to 0 for the text lookup,
+         # those rows are overwritten with audio embeddings below.
+         input_embeddings = self.text_embeddings(torch.clamp(input_ids, min=0))
+         if audio_features is None:
+             return input_embeddings
+         bs, max_num_img, l, d = audio_features.shape
+         audio_embeddings = self.audio_embedding(input_features=audio_features.view((bs * max_num_img, l, d)))['last_hidden_state']
+         audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, self.start_of_audio.shape[-1]))
+
+         for i in range(bs):
+             for j in range(max_num_img):
+                 # The j-th audio clip of sample i is encoded as token id -(j + 1).
+                 audio_id = -1 - j
+                 if torch.any(input_ids[i] == audio_id):
+                     positions = torch.nonzero(input_ids[i] == audio_id, as_tuple=True)
+                     seq_len = input_embeddings[i][positions].shape[0] - 2
+                     input_embeddings[i] = input_embeddings[i].index_put(
+                         positions,
+                         torch.concat([self.start_of_audio, audio_embeddings[i, j, :seq_len, :], self.end_of_audio]),
+                         accumulate=False,
+                     )
+         return input_embeddings
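
The convention here is that the processor rewrites the j-th audio clip of a sample into a run of placeholder tokens whose ids are remapped to -(j + 1); `forward` then overwrites those rows with the start-of-audio vector, the adapter outputs, and the end-of-audio vector. Below is a toy-sized sketch of that flow; every size is hypothetical and chosen only so the module runs quickly, not taken from the real checkpoint.

```python
import torch
from transformers import Wav2Vec2BertConfig
from transformers.models.mllama.configuration_mllama import MllamaTextConfig

text_cfg = MllamaTextConfig(vocab_size=1000, hidden_size=64, pad_token_id=0)
audio_cfg = Wav2Vec2BertConfig(
    hidden_size=64, output_hidden_size=64, add_adapter=True,
    num_hidden_layers=1, num_adapter_layers=1, num_attention_heads=2,
)
embed = Llama3Embedding(audio_cfg, text_cfg).eval()

# Three placeholder positions (id -1) reserve room for start + one audio frame + end.
input_ids = torch.tensor([[1, -1, -1, -1, 2]])
audio_features = torch.randn(1, 1, 20, audio_cfg.feature_projection_input_dim)
with torch.no_grad():
    out = embed(input_ids=input_ids, audio_features=audio_features)
print(out.shape)  # torch.Size([1, 5, 64])
```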
modeling_llama3.py ADDED
@@ -0,0 +1,674 @@
+ # coding=utf-8
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+
+ import transformers
+ from transformers import MllamaPreTrainedModel, MllamaVisionModel, MllamaForCausalLM, Wav2Vec2BertConfig, AutoModel, AutoModelForCausalLM
+ from transformers.cache_utils import Cache, StaticCache
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
+ from transformers.utils import logging
+ from transformers.models.mllama.modeling_mllama import _prepare_cross_attention_mask, MllamaCrossAttentionDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaTextRMSNorm, MllamaRotaryEmbedding
+ from transformers.models.mllama.configuration_mllama import MllamaTextConfig
+ from .configuration_llama3 import Llama3Config
+ from .mllama_audio_model import Llama3Embedding
+
+
+ logger = logging.get_logger(__name__)
+
+ class Llama3PreTrainedModel(MllamaPreTrainedModel):
+     config_class = Llama3Config
+     base_model_prefix = "model"
+
+ class Llama3TextModel(MllamaPreTrainedModel):
+     config_class = MllamaTextConfig
+     base_model_prefix = "language_model.model"
+
+     def __init__(self, config: MllamaTextConfig):
+         super().__init__(config)
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+         #self.embed_tokens = Llama3Embedding(audio_config, config)
+         self.cross_attention_layers = config.cross_attention_layers
+
+         layers = []
+         for layer_idx in range(config.num_hidden_layers):
+             if layer_idx in self.cross_attention_layers:
+                 layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx))
+             else:
+                 layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx))
+
+         self.layers = nn.ModuleList(layers)
+         self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = MllamaRotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+         self.post_init()
+
+     def get_input_embeddings(self):
+         #return self.embed_tokens.text_embeddings
+         return None
+
+     def set_input_embeddings(self, value):
+         #self.embed_tokens.text_embeddings = value
+         pass
+
+     def forward(
+         self,
+         #input_ids: Optional[torch.LongTensor] = None,
+         #audio_features: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         cross_attention_states: Optional[torch.FloatTensor] = None,
+         cross_attention_mask: Optional[torch.Tensor] = None,
+         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+         """
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from transformers import AutoProcessor, MllamaTextModel
+
+         >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
+         >>> model = MllamaTextModel.from_pretrained(checkpoint)
+         >>> processor = AutoProcessor.from_pretrained(checkpoint)
+
+         >>> text = "<|image|>If I had to write a haiku for this one"
+         >>> inputs = processor(text=text, return_tensors="pt")
+
+         >>> output = model(**inputs)
+
+         >>> print(output.last_hidden_state.shape)
+         torch.Size([1, 13, 4096])
+         ```
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         #if (input_ids is None) ^ (inputs_embeds is not None):
+         #    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once(
+                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+             )
+             use_cache = False
+
+         #if inputs_embeds is None:
+         #    inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+
+         hidden_states = inputs_embeds
+
+         if cache_position is None:
+             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+             cache_position = torch.arange(
+                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+             )
+         if position_ids is None:
+             position_ids = cache_position.unsqueeze(0)
+
+         causal_mask = self._update_causal_mask(
+             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+         )
+
+         # create position embeddings to be shared across the decoder layers
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         next_decoder_cache = None
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             # For text-only path we should skip cross attention layers.
+             # Let's check if the layer is cross attention layer and if we have cross attention states
+             # or cached cross attention states.
+             is_cross_attention_layer = idx in self.cross_attention_layers
+             is_cross_attention_cache_empty = past_key_values is None or (
+                 past_key_values is not None and past_key_values.get_seq_length(idx) == 0
+             )
+
+             if is_cross_attention_layer and cross_attention_states is None and is_cross_attention_cache_empty:
+                 continue
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     decoder_layer.__call__,
+                     hidden_states,
+                     cross_attention_states,
+                     cross_attention_mask,
+                     causal_mask,
+                     full_text_row_masked_out_mask,
+                     position_ids,
+                     past_key_values,
+                     output_attentions,
+                     use_cache,
+                     cache_position,
+                     position_embeddings,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     cross_attention_states=cross_attention_states,
+                     cross_attention_mask=cross_attention_mask,
+                     attention_mask=causal_mask,
+                     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                     position_ids=position_ids,
+                     past_key_value=past_key_values,
+                     output_attentions=output_attentions,
+                     use_cache=use_cache,
+                     cache_position=cache_position,
+                     position_embeddings=position_embeddings,
+                 )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+     # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
+     def _update_causal_mask(
+         self,
+         attention_mask: torch.Tensor,
+         input_tensor: torch.Tensor,
+         cache_position: torch.Tensor,
+         past_key_values: Cache,
+         output_attentions: bool,
+     ):
+         if self.config._attn_implementation == "flash_attention_2":
+             if attention_mask is not None and 0.0 in attention_mask:
+                 return attention_mask
+             return None
+
+         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+         # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+         # to infer the attention mask.
+         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+         using_static_cache = isinstance(past_key_values, StaticCache)
+
+         # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+         if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+             if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                 attention_mask,
+                 inputs_embeds=input_tensor,
+                 past_key_values_length=past_seen_tokens,
+                 is_training=self.training,
+             ):
+                 return None
+
+         dtype, device = input_tensor.dtype, input_tensor.device
+         sequence_length = input_tensor.shape[1]
+         if using_static_cache:
+             target_length = past_key_values.get_max_cache_shape()
+         else:
+             target_length = (
+                 attention_mask.shape[-1]
+                 if isinstance(attention_mask, torch.Tensor)
+                 else past_seen_tokens + sequence_length + 1
+             )
+
+         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+             attention_mask,
+             sequence_length=sequence_length,
+             target_length=target_length,
+             dtype=dtype,
+             device=device,
+             cache_position=cache_position,
+             batch_size=input_tensor.shape[0],
+         )
+
+         if (
+             self.config._attn_implementation == "sdpa"
+             and attention_mask is not None
+             and attention_mask.device.type == "cuda"
+             and not output_attentions
+         ):
+             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+             # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+             # Details: https://github.com/pytorch/pytorch/issues/110213
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+         return causal_mask
+
+     @staticmethod
+     # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+     def _prepare_4d_causal_attention_mask_with_cache_position(
+         attention_mask: torch.Tensor,
+         sequence_length: int,
+         target_length: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         cache_position: torch.Tensor,
+         batch_size: int,
+         **kwargs,
+     ):
+         if attention_mask is not None and attention_mask.dim() == 4:
+             # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+             causal_mask = attention_mask
+         else:
+             min_dtype = torch.finfo(dtype).min
+             causal_mask = torch.full(
+                 (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+             )
+             if sequence_length != 1:
+                 causal_mask = torch.triu(causal_mask, diagonal=1)
+             causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+             causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+             if attention_mask is not None:
+                 causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                 mask_length = attention_mask.shape[-1]
+                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                 padding_mask = padding_mask == 0
+                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                     padding_mask, min_dtype
+                 )
+
+         return causal_mask
+
+ class Llama3ForCausalLM(MllamaPreTrainedModel, GenerationMixin):
+     config_class = MllamaTextConfig
+     base_model_prefix = "model"
+     #_tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: MllamaTextConfig):
+         super().__init__(config.get_text_config())
+         self.text_config = config.get_text_config()
+         self.vocab_size = self.text_config.vocab_size
+         self.model = Llama3TextModel._from_config(self.text_config, attn_implementation=config._attn_implementation)
+         self.lm_head = nn.Linear(self.text_config.hidden_size, self.vocab_size, bias=False)
+
+         self.post_init()
+
+     def get_input_embeddings(self):
+         #return self.model.embed_tokens.text_embeddings
+         return None
+
+     def set_input_embeddings(self, value):
+         #self.model.embed_tokens.text_embeddings = value
+         pass
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         #input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         cross_attention_states: Optional[torch.LongTensor] = None,
+         cross_attention_mask: Optional[torch.LongTensor] = None,
+         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+         **loss_kwargs,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             #input_ids=input_ids,
+             cross_attention_states=cross_attention_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             cross_attention_mask=cross_attention_mask,
+             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ #AutoModelForCausalLM.register(MllamaTextConfig, Llama3ForCausalLM)
+ #transformers.Llama3ForCausalLM = Llama3ForCausalLM
+
+ class Llama3ForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+     config_class = Llama3Config
+     base_model_prefix = "model"
+     _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
+
+     def __init__(self, config: Llama3Config):
+         super().__init__(config)
+         self.vocab_size = config.text_config.vocab_size
+         self.hidden_size = config.text_config.hidden_size
+         self.max_num_tiles = config.vision_config.max_num_tiles
+         self.vision_output_dim = config.vision_config.vision_output_dim
+         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+
+         self.vision_model = MllamaVisionModel._from_config(config.vision_config)
+         self.language_model = MllamaForCausalLM._from_config(config.text_config)
+         # Freeze the language model's own embedding table; the audio-aware embeddings come from self.embed_tokens.
+         self.language_model.get_input_embeddings().weight.requires_grad = False
+         self.embed_tokens = Llama3Embedding(config.audio_config, config.text_config)
+         self.multi_modal_projector = nn.Linear(
+             config.vision_config.vision_output_dim,
+             config.text_config.hidden_size,
+             bias=True,
+         )
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.language_model.set_input_embeddings(value)
+
+     def get_output_embeddings(self):
+         return self.language_model.get_output_embeddings()
+
+     def set_output_embeddings(self, new_embeddings):
+         self.language_model.set_output_embeddings(new_embeddings)
+
+     def set_decoder(self, decoder):
+         self.language_model.set_decoder(decoder)
+
+     def get_decoder(self):
+         return self.language_model.get_decoder()
+
+     def tie_weights(self):
+         return self.language_model.tie_weights()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         audio_features: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         aspect_ratio_mask: Optional[torch.Tensor] = None,
+         aspect_ratio_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_mask: Optional[torch.Tensor] = None,
+         cross_attention_states: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         num_logits_to_keep: int = 0,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         r"""
+         Args:
+             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+             num_logits_to_keep (`int`, *optional*):
+                 Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+                 `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+                 token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> from PIL import Image
+         >>> import requests
+         >>> from transformers import AutoProcessor, MllamaForConditionalGeneration
+
+         >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
+         >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
+         >>> processor = AutoProcessor.from_pretrained(checkpoint)
+
+         >>> prompt = "<|image|>If I had to write a haiku for this one"
+         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+
+         >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+
+         >>> # Generate
+         >>> output = model.generate(**inputs, max_new_tokens=15)
+
+         >>> prompt_len = inputs.input_ids.shape[-1]
+         >>> generated_ids = output[:, prompt_len:]
+         >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+         >>> print(generated_text)
+         [', it would be:.\\nA stop sign in Chinatown.\\n']
+         ```
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.text_config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.text_config.use_return_dict
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if pixel_values is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+             )
+
+         if pixel_values is not None and cross_attention_states is not None:
+             raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")
+
+         if pixel_values is not None:
+             if aspect_ratio_ids is None:
+                 raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
+             # get vision tokens from vision model
+             vision_outputs = self.vision_model(
+                 pixel_values=pixel_values,
+                 aspect_ratio_ids=aspect_ratio_ids,
+                 aspect_ratio_mask=aspect_ratio_mask,
+                 output_hidden_states=output_hidden_states,
+                 output_attentions=output_attentions,
+                 return_dict=return_dict,
+             )
+             cross_attention_states = vision_outputs[0]
+             cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
+                 -1, cross_attention_states.shape[-2], self.hidden_size
+             )
+
+         if cross_attention_mask is not None:
+             cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
+                 cross_attention_mask,
+                 num_vision_tokens=self.vision_model.num_patches,
+                 dtype=self.dtype,
+             )
+         else:
+             full_text_row_masked_out_mask = None
+
+         if cross_attention_mask is not None and cache_position is not None:
+             cross_attention_mask = cross_attention_mask[:, :, cache_position]
+             full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids=input_ids, audio_features=audio_features)
+
+         outputs = self.language_model(
+             #input_ids=input_ids,
+             #audio_features=audio_features,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             cross_attention_states=cross_attention_states,
+             cross_attention_mask=cross_attention_mask,
+             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             inputs_embeds=inputs_embeds,
+             labels=labels,
+             output_hidden_states=output_hidden_states,
+             output_attentions=output_attentions,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             num_logits_to_keep=num_logits_to_keep,
+         )
+
+         return outputs
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids=None,
+         audio_features=None,
+         inputs_embeds=None,
+         attention_mask=None,
+         position_ids=None,
+         pixel_values=None,
+         aspect_ratio_ids=None,
+         aspect_ratio_mask=None,
+         cross_attention_mask=None,
+         past_key_values=None,
+         use_cache=False,
+         cache_position=None,
+         num_logits_to_keep=None,
+         **kwargs,
+     ):
+         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+         # Exception 1: when passing input_embeds, input_ids may be missing entries
+         # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+         if past_key_values is not None:
+             if inputs_embeds is not None:  # Exception 1
+                 input_ids = input_ids[:, -cache_position.shape[0] :]
+             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                 input_ids = input_ids[:, cache_position]
+
+         # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+                 # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
+                 position_ids = position_ids.clone(memory_format=torch.contiguous_format)
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and cache_position[0] == 0:
+             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
+         else:
+             # The clone here is for the same reason as for `position_ids`.
+             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+
+         if num_logits_to_keep is not None:
+             model_inputs["num_logits_to_keep"] = num_logits_to_keep
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "cache_position": cache_position,
+                 "past_key_values": past_key_values,
+                 "use_cache": use_cache,
+                 "attention_mask": attention_mask,
+                 "cross_attention_mask": cross_attention_mask,
+             }
+         )
+
+         # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
+         # to compute image hidden states, otherwise they are cached within each cross attn layer
+         if cache_position[0] == 0:
+             model_inputs["pixel_values"] = pixel_values
+             model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
+             model_inputs["aspect_ratio_mask"] = aspect_ratio_mask
+
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+         cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
+         model_kwargs = super()._update_model_kwargs_for_generation(
+             outputs=outputs,
+             model_kwargs=model_kwargs,
+             is_encoder_decoder=is_encoder_decoder,
+             **kwargs,
+         )
+
+         # add cross-attn mask for new token
+         if cross_attention_mask_prev is not None:
+             model_kwargs["cross_attention_mask"] = torch.cat(
+                 [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
+             )
+         return model_kwargs
+
+ #AutoModel.register(Llama3Config, Llama3ForConditionalGeneration)
+ #transformers.Llama3ForConditionalGeneration = Llama3ForConditionalGeneration
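
Most of the attention plumbing above is copied from the upstream Llama/Mllama implementations. As a small self-contained illustration, the static mask builder can be called directly; here a batch of one left-padded prompt of length 4 is masked, with `min_dtype` marking disallowed positions.

```python
import torch

attention_mask = torch.tensor([[0, 1, 1, 1]])  # first position is padding
mask = Llama3TextModel._prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask,
    sequence_length=4,
    target_length=4,
    dtype=torch.float32,
    device="cpu",
    cache_position=torch.arange(4),
    batch_size=1,
)
print(mask.shape)          # torch.Size([1, 1, 4, 4])
print((mask == 0).int())   # 1 where attention is allowed: causal, with the pad column masked out
```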
processing_mllama.py ADDED
@@ -0,0 +1,377 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Processor class for Mllama."""
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+ import torch
+ import transformers
+ from transformers import AutoProcessor
+ from transformers.feature_extraction_utils import BatchFeature
+ from transformers.image_utils import ImageInput
+ from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack, AudioKwargs
+ from transformers.tokenization_utils_base import (
+     PreTokenizedInput,
+     TextInput,
+     AudioInput,
+ )
+
+ # TODO: Can we do it that way, or is it better to include it as "Copied from ..."?
+ from transformers.models.mllama.image_processing_mllama import make_list_of_images
+ from .audio_processing_mllama import build_audio_tokens
+
+
+ class MllamaImagesKwargs(ImagesKwargs, total=False):
+     max_image_tiles: Optional[int]
+
+ class MllamaProcessorKwargs(ProcessingKwargs, total=False):
+     images_kwargs: MllamaImagesKwargs
+
+     _defaults = {
+         "images_kwargs": {
+             "max_image_tiles": 4,
+         },
+     }
+
+
+ def get_cross_attention_token_mask(input_ids: List[int], image_token_id: int) -> List[List[int]]:
+     """
+     Generate a cross-attention token mask for image tokens in the input sequence.
+
+     This function identifies the positions of image tokens in the input sequence and creates
+     a mask that defines which subsequent tokens each image token should attend to.
+
+     Args:
+         input_ids (List[int]): A list of token ids representing the input sequence.
+         image_token_id (int): The id of the token used to represent images in the sequence.
+
+     Returns:
+         List[List[int]]: A list of [start, end] pairs, where each pair represents the range
+         of tokens an image token should attend to.
+
+     Notes:
+         - If no image tokens are present, an empty list is returned.
+         - For a single image token, it attends to all subsequent tokens until the end of the sequence.
+         - For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
+         - Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
+     """
+
+     image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]
+
+     if len(image_token_locations) == 0:
+         return []
+
+     # only one image present, unmask until end of sequence
+     if len(image_token_locations) == 1:
+         return [[image_token_locations[0], -1]]
+
+     vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]
+
+     # last image will attend to all subsequent text
+     vision_masks.append([image_token_locations[-1], len(input_ids)])
+
+     # if there are two or more consecutive vision tokens,
+     # they should all attend to all subsequent
+     # text present
+     last_mask_end = vision_masks[-1][1]
+     for vision_mask in vision_masks[::-1]:
+         if vision_mask[0] == vision_mask[1] - 1:
+             vision_mask[1] = last_mask_end
+         last_mask_end = vision_mask[1]
+
+     return vision_masks
+
+
+ def convert_sparse_cross_attention_mask_to_dense(
+     cross_attention_token_mask: List[List[List[int]]],
+     num_tiles: List[List[int]],
+     max_num_tiles: int,
+     length: int,
+ ) -> np.ndarray:
+     """
+     Convert the cross attention mask indices to a cross attention mask 4D array.
+
+     This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
+     The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.
+
+     Args:
+         cross_attention_token_mask (List[List[List[int]]]): A nested list structure where:
+             - The outer list represents the batch dimension.
+             - The middle list represents different images within each batch item.
+             - The inner list contains pairs of integers [start, end] representing token ranges for each image.
+         num_tiles (List[List[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
+         max_num_tiles (int): The maximum possible number of tiles.
+         length (int): The total sequence length of the input.
+
+     Returns:
+         np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles).
+             The array contains `1` where attention is allowed and `0` where it is not.
+
+     Note:
+         - Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
+     """
+
+     batch_size = len(cross_attention_token_mask)
+     max_num_images = max([len(masks) for masks in cross_attention_token_mask])
+
+     cross_attention_mask = np.zeros(
+         shape=(batch_size, length, max_num_images, max_num_tiles),
+         dtype=np.int64,
+     )
+
+     for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
+         for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
+             if len(locations) == 2:
+                 start, end = locations
+                 end = min(end, length)
+                 if end == -1:
+                     end = length
+                 cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
+     return cross_attention_mask
+
+
+ def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
+     """
+     Builds a string from the input prompt by adding `bos_token` if not already present.
+
+     Args:
+         prompt (`str`):
+             The input prompt string.
+         bos_token (`str`):
+             The beginning of sentence token to be added.
+         image_token (`str`):
+             The image token used to identify the start of an image sequence.
+
+     Returns:
+         str: The modified prompt string with the `bos_token` added if necessary.
+
+     Examples:
+         >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
+         '<begin_of_text>Hello world'
+
+         >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
+         '<|image|><begin_of_text>Hello world'
+
+         >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
+         '<begin_of_text>Hello world'
+     """
+
+     if bos_token in prompt:
+         return prompt
+
+     num_image_tokens_on_start = 0
+     while prompt.startswith(image_token):
+         prompt = prompt[len(image_token) :]
+         num_image_tokens_on_start += 1
+
+     return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"
+
+
+ class MllamaProcessor(ProcessorMixin):
+     r"""
+     Constructs a Mllama processor which wraps [`MllamaImageProcessor`], [`MllamaAudioFeatureExtractor`], and
+     [`PretrainedTokenizerFast`] into a single processor that inherits the image processor, audio feature extractor, and
+     tokenizer functionalities. See the [`~MllamaProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more
+     information.
+     The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
+     ```python
+     from transformers import MllamaProcessor
+     from PIL import Image
+
+     processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
+
+     processor(
+         images=your_pil_image,
+         text=["<|image|>If I had to write a haiku for this one"],
+         images_kwargs = {"size": {"height": 448, "width": 448}},
+         text_kwargs = {"padding": "right"},
+         common_kwargs = {"return_tensors": "pt"},
+     )
+     ```
+
+     Args:
+         image_processor ([`MllamaImageProcessor`]):
+             The image processor is a required input.
+         audio_processor ([`MllamaAudioFeatureExtractor`]):
+             The audio feature extractor is a required input.
+         tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
+             The tokenizer is a required input.
+
+     """
+
+     attributes = ["image_processor", "audio_processor", "tokenizer"]
+     image_processor_class = "MllamaImageProcessor"
+     audio_processor_class = "MllamaAudioFeatureExtractor"
+     tokenizer_class = "PreTrainedTokenizerFast"
+
+     def __init__(self, image_processor, audio_processor, tokenizer):
+         self.image_token = "<|image|>"
+         self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+         self.audio_token = "<|audio|>"
+         self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+         self.python_token = "<|python_tag|>"
+         self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
+         self.bos_token = tokenizer.bos_token
+         self.chat_template = tokenizer.chat_template
+         super().__init__(image_processor, audio_processor, tokenizer)
+         self.tokenizer.add_tokens([f"<|audio_{i}|>" for i in range(1, 50)])
+
+     def __call__(
+         self,
+         images: Optional[ImageInput] = None,
+         text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+         audio: Optional[Union[AudioInput, List[AudioInput]]] = None,
+         videos=None,
+         **kwargs: Unpack[MllamaProcessorKwargs],
+     ) -> BatchFeature:
+         """
+         Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
+         arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
+         the text. To prepare the image(s), this method forwards the `images` arguments to
+         MllamaImageProcessor's [`~MllamaImageProcessor.__call__`] if `images` is not `None`. Please refer
+         to the docstring of the above two methods for more information.
+
+         Args:
+             images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                 tensor. Both channels-first and channels-last formats are supported.
+             text (`str`, `List[str]`, `List[List[str]]`):
+                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+             return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                 If set, will return tensors of a particular framework. Acceptable values are:
+                 - `'tf'`: Return TensorFlow `tf.constant` objects.
+                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                 - `'np'`: Return NumPy `np.ndarray` objects.
+                 - `'jax'`: Return JAX `jnp.ndarray` objects.
+         Returns:
+             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+             - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+               `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+               `None`).
+             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+             - **audio_features** -- Audio features extracted using SeamlessM4TFeatureExtractor. Returned when `audio` is not `None`.
+             TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
+         """
+         if text is None:
+             raise ValueError("You must specify text.")
+
+         output_kwargs = self._merge_kwargs(
+             MllamaProcessorKwargs,
+             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+             **kwargs,
+         )
+
+         text_kwargs = output_kwargs["text_kwargs"]
+         images_kwargs = output_kwargs["images_kwargs"]
+         common_kwargs = output_kwargs["common_kwargs"]
+
+         data = {}
+
+         if audio is not None:
+             audio_features = self.audio_processor(audio)
+             data.update(audio_features)
+
+         if isinstance(text, str):
+             text = [text]
+         elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
+             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
+         n_images_in_text = [t.count(self.image_token) for t in text]
+         text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
+         _ = text_kwargs.pop("padding_side", None)  # hack until padding-side is an accepted kwarg by tokenizers
+
+         if audio is not None:
+             # Expand each <|audio|> into one placeholder token per audio embedding.
+             text = build_audio_tokens(text, audio_features, self.audio_token)
+
+         encoding = self.tokenizer(text, add_special_tokens=False, **text_kwargs)
+         if audio is not None:
+             # Remap the <|audio_k|> placeholder ids to the negative ids (-1, -2, ...) expected by Llama3Embedding.
+             beg_audio_id = self.tokenizer.convert_tokens_to_ids("<|audio_1|>")
+             idx = torch.where(encoding['input_ids'] >= beg_audio_id)
+             encoding['input_ids'][idx] = beg_audio_id - encoding['input_ids'][idx] - 1
+         data.update(encoding)
+
+         n_images_in_images = [0]
+         if images is not None:
+             images = make_list_of_images(images)
+             n_images_in_images = [len(sample) for sample in images]
+
+         if text is not None:
+             if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
+                 batch_img == 0 for batch_img in n_images_in_text
+             ):
+                 raise ValueError(
+                     "If a batch of text is provided, there should be either no images or at least one image per sample"
+                 )
+             if sum(n_images_in_images) != sum(n_images_in_text):
+                 if images is None:
+                     raise ValueError("No images were provided, but there are image tokens in the prompt")
+                 else:
+                     raise ValueError(
+                         f"The number of image tokens ({sum(n_images_in_text)}) should match the number of provided images ({sum(n_images_in_images)})"
+                     )
+
+         if images is not None:
+             image_features = self.image_processor(images, **images_kwargs)
+             num_tiles = image_features.pop("num_tiles")
+             data.update(image_features)
+
+         # Create cross attention mask
+         if images is not None and text is not None:
+             cross_attention_token_mask = [
+                 get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
+             ]
+             cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
+                 cross_attention_token_mask,
+                 num_tiles=num_tiles,
+                 max_num_tiles=self.image_processor.max_image_tiles,
+                 length=max(len(input_ids) for input_ids in encoding["input_ids"]),
+             )
+             data["cross_attention_mask"] = cross_attention_mask
+
+         return_tensors = common_kwargs.pop("return_tensors", None)
+         batch_feature = BatchFeature(data=data, tensor_type=return_tensors)
+
+         return batch_feature
+
+     def batch_decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+         refer to the docstring of this method for more information.
+         """
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args, **kwargs):
+         """
+         This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+         the docstring of this method for more information.
+         """
+         return self.tokenizer.decode(*args, **kwargs)
+
+     @property
+     def model_input_names(self):
+         tokenizer_input_names = self.tokenizer.model_input_names
+         image_processor_input_names = self.image_processor.model_input_names
+         audio_processor_input_names = self.audio_processor.model_input_names
+         return list(
+             tokenizer_input_names
+             + image_processor_input_names
+             + ["cross_attention_mask"]
+             + audio_processor_input_names
+         )
+
+ AutoProcessor.register("MllamaProcessor", MllamaProcessor)
+ transformers.MllamaProcessor = MllamaProcessor
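
End to end, the processor ties the pieces together. The sketch below shows only the intended call pattern; `"path/to/llama3-audio-checkpoint"` is a placeholder for a repository that ships these classes via `trust_remote_code`, and the output keys follow the `__call__` implementation above.

```python
import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("path/to/llama3-audio-checkpoint", trust_remote_code=True)

audio_clip = np.random.randn(16000).astype(np.float32)  # 1 s of 16 kHz audio
inputs = processor(
    text=["<|audio|>Describe what you hear."],
    audio=[[audio_clip]],  # one sample with one clip
    return_tensors="pt",
)
# input_ids contain negative placeholder ids where audio embeddings will be spliced in,
# and audio_features is the packed (batch, clips, frames, features) array.
print(inputs["input_ids"].shape, inputs["audio_features"].shape)
```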